working on tensor

This commit is contained in:
Jiri Pittner 2024-04-04 12:12:12 +02:00
parent baee11489b
commit 87dd0c5b65
2 changed files with 40 additions and 9 deletions

View File

@ -158,6 +158,27 @@ LA_largeindex Tensor<T>::vindex(int *sign, int i1, va_list args) const
}
//binary I/O
template<typename T>
void Tensor<T>::put(int fd) const
{
shape.put(fd,true);
cumsizes.put(fd,true);
data.put(fd,true);
}
template<typename T>
void Tensor<T>::get(int fd)
{
shape.get(fd,true);
cumsizes.get(fd,true);
data.get(fd,true);
}
template class Tensor<double>;
template class Tensor<std::complex<double> >;

View File

@ -43,8 +43,12 @@ T *ptr;
int sgn;
public:
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
//@@@@@@operations like += etc. on singedpointer as LHS of the non-const tensor.operator() expressions
//dereferencing *ptr should intentionally segfault for sgn==0
T& operator=(const T rhs) {if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
T& operator*=(const T rhs) {*ptr *= rhs; return *ptr;}
T& operator/=(const T rhs) {*ptr /= rhs; return *ptr;}
T& operator+=(const T rhs) {if(sgn>0) *ptr += rhs; else *ptr -= rhs; return *ptr;}
T& operator-=(const T rhs) {if(sgn>0) *ptr -= rhs; else *ptr += rhs; return *ptr;}
};
@ -65,6 +69,11 @@ class LA_traits<indexgroup> {
static bool is_plaindata() {return true;};
static void copyonwrite(indexgroup& x) {};
typedef INDEXGROUP normtype;
static inline void put(int fd, const indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=write(fd,&x,sizeof(indexgroup))) laerror("write error 1 in indexgroup put"); }
static inline void multiput(int nn, int fd, const indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=write(fd,x,nn*sizeof(indexgroup))) laerror("write error 1 in indexgroup multiiput"); }
static inline void get(int fd, indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=read(fd,&x,sizeof(indexgroup))) laerror("read error 1 in indexgroup get");}
static inline void multiget(int nn, int fd, indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=read(fd,x,nn*sizeof(indexgroup))) laerror("read error 1 in indexgroup get");}
};
@ -74,13 +83,10 @@ typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP stru
template<typename T>
class Tensor {
//essential data
NRVec<indexgroup> shape;
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups (a function of shape but precomputed for efficiency)
NRVec<T> data;
//redundant data to facilitate efficient indexing
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups
private:
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
@ -102,14 +108,18 @@ public:
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline Signedpointer<T> lhs(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
//NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it
inline Tensor& operator*=(const T &a) {data*=a; return *this;};
inline Tensor& operator/=(const T &a) {data/=a; return *this;};
void put(int fd) const;
void get(int fd);
//@@@TODO - unwinding to full size in a specified index
//@@@TODO - contractions - basic and efficient
//@@@TODO get/put to file, stream i/o
//@@@dodelat indexy
//@@@ dvojite rekurzivni loopover s callbackem - nebo iterator s funkci next???
//@@@ stream i/o na zaklade tohoto
};