working on tensor
This commit is contained in:
parent
baee11489b
commit
87dd0c5b65
21
tensor.cc
21
tensor.cc
@ -158,6 +158,27 @@ LA_largeindex Tensor<T>::vindex(int *sign, int i1, va_list args) const
|
||||
}
|
||||
|
||||
|
||||
//binary I/O
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::put(int fd) const
|
||||
{
|
||||
shape.put(fd,true);
|
||||
cumsizes.put(fd,true);
|
||||
data.put(fd,true);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::get(int fd)
|
||||
{
|
||||
shape.get(fd,true);
|
||||
cumsizes.get(fd,true);
|
||||
data.get(fd,true);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
|
||||
|
28
tensor.h
28
tensor.h
@ -43,8 +43,12 @@ T *ptr;
|
||||
int sgn;
|
||||
public:
|
||||
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
|
||||
T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
|
||||
//@@@@@@operations like += etc. on singedpointer as LHS of the non-const tensor.operator() expressions
|
||||
//dereferencing *ptr should intentionally segfault for sgn==0
|
||||
T& operator=(const T rhs) {if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
|
||||
T& operator*=(const T rhs) {*ptr *= rhs; return *ptr;}
|
||||
T& operator/=(const T rhs) {*ptr /= rhs; return *ptr;}
|
||||
T& operator+=(const T rhs) {if(sgn>0) *ptr += rhs; else *ptr -= rhs; return *ptr;}
|
||||
T& operator-=(const T rhs) {if(sgn>0) *ptr -= rhs; else *ptr += rhs; return *ptr;}
|
||||
};
|
||||
|
||||
|
||||
@ -65,6 +69,11 @@ class LA_traits<indexgroup> {
|
||||
static bool is_plaindata() {return true;};
|
||||
static void copyonwrite(indexgroup& x) {};
|
||||
typedef INDEXGROUP normtype;
|
||||
static inline void put(int fd, const indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=write(fd,&x,sizeof(indexgroup))) laerror("write error 1 in indexgroup put"); }
|
||||
static inline void multiput(int nn, int fd, const indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=write(fd,x,nn*sizeof(indexgroup))) laerror("write error 1 in indexgroup multiiput"); }
|
||||
static inline void get(int fd, indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=read(fd,&x,sizeof(indexgroup))) laerror("read error 1 in indexgroup get");}
|
||||
static inline void multiget(int nn, int fd, indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=read(fd,x,nn*sizeof(indexgroup))) laerror("read error 1 in indexgroup get");}
|
||||
|
||||
};
|
||||
|
||||
|
||||
@ -74,13 +83,10 @@ typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP stru
|
||||
|
||||
template<typename T>
|
||||
class Tensor {
|
||||
//essential data
|
||||
NRVec<indexgroup> shape;
|
||||
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups (a function of shape but precomputed for efficiency)
|
||||
NRVec<T> data;
|
||||
|
||||
//redundant data to facilitate efficient indexing
|
||||
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups
|
||||
|
||||
private:
|
||||
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
|
||||
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
|
||||
@ -102,14 +108,18 @@ public:
|
||||
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline Signedpointer<T> lhs(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
||||
inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
|
||||
//NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it
|
||||
|
||||
inline Tensor& operator*=(const T &a) {data*=a; return *this;};
|
||||
inline Tensor& operator/=(const T &a) {data/=a; return *this;};
|
||||
|
||||
void put(int fd) const;
|
||||
void get(int fd);
|
||||
|
||||
//@@@TODO - unwinding to full size in a specified index
|
||||
//@@@TODO - contractions - basic and efficient
|
||||
//@@@TODO get/put to file, stream i/o
|
||||
//@@@dodelat indexy
|
||||
//@@@ dvojite rekurzivni loopover s callbackem - nebo iterator s funkci next???
|
||||
//@@@ stream i/o na zaklade tohoto
|
||||
};
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user