working on tensor
This commit is contained in:
parent
baee11489b
commit
87dd0c5b65
21
tensor.cc
21
tensor.cc
@ -158,6 +158,27 @@ LA_largeindex Tensor<T>::vindex(int *sign, int i1, va_list args) const
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//binary I/O
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::put(int fd) const
|
||||||
|
{
|
||||||
|
shape.put(fd,true);
|
||||||
|
cumsizes.put(fd,true);
|
||||||
|
data.put(fd,true);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::get(int fd)
|
||||||
|
{
|
||||||
|
shape.get(fd,true);
|
||||||
|
cumsizes.get(fd,true);
|
||||||
|
data.get(fd,true);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template class Tensor<double>;
|
template class Tensor<double>;
|
||||||
template class Tensor<std::complex<double> >;
|
template class Tensor<std::complex<double> >;
|
||||||
|
|
||||||
|
28
tensor.h
28
tensor.h
@ -43,8 +43,12 @@ T *ptr;
|
|||||||
int sgn;
|
int sgn;
|
||||||
public:
|
public:
|
||||||
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
|
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
|
||||||
T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
|
//dereferencing *ptr should intentionally segfault for sgn==0
|
||||||
//@@@@@@operations like += etc. on singedpointer as LHS of the non-const tensor.operator() expressions
|
T& operator=(const T rhs) {if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
|
||||||
|
T& operator*=(const T rhs) {*ptr *= rhs; return *ptr;}
|
||||||
|
T& operator/=(const T rhs) {*ptr /= rhs; return *ptr;}
|
||||||
|
T& operator+=(const T rhs) {if(sgn>0) *ptr += rhs; else *ptr -= rhs; return *ptr;}
|
||||||
|
T& operator-=(const T rhs) {if(sgn>0) *ptr -= rhs; else *ptr += rhs; return *ptr;}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -65,6 +69,11 @@ class LA_traits<indexgroup> {
|
|||||||
static bool is_plaindata() {return true;};
|
static bool is_plaindata() {return true;};
|
||||||
static void copyonwrite(indexgroup& x) {};
|
static void copyonwrite(indexgroup& x) {};
|
||||||
typedef INDEXGROUP normtype;
|
typedef INDEXGROUP normtype;
|
||||||
|
static inline void put(int fd, const indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=write(fd,&x,sizeof(indexgroup))) laerror("write error 1 in indexgroup put"); }
|
||||||
|
static inline void multiput(int nn, int fd, const indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=write(fd,x,nn*sizeof(indexgroup))) laerror("write error 1 in indexgroup multiiput"); }
|
||||||
|
static inline void get(int fd, indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=read(fd,&x,sizeof(indexgroup))) laerror("read error 1 in indexgroup get");}
|
||||||
|
static inline void multiget(int nn, int fd, indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=read(fd,x,nn*sizeof(indexgroup))) laerror("read error 1 in indexgroup get");}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -74,13 +83,10 @@ typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP stru
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
class Tensor {
|
class Tensor {
|
||||||
//essential data
|
|
||||||
NRVec<indexgroup> shape;
|
NRVec<indexgroup> shape;
|
||||||
|
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups (a function of shape but precomputed for efficiency)
|
||||||
NRVec<T> data;
|
NRVec<T> data;
|
||||||
|
|
||||||
//redundant data to facilitate efficient indexing
|
|
||||||
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
|
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
|
||||||
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
|
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
|
||||||
@ -102,14 +108,18 @@ public:
|
|||||||
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
inline Signedpointer<T> lhs(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
inline Signedpointer<T> lhs(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
||||||
inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
|
|
||||||
//NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it
|
|
||||||
|
|
||||||
inline Tensor& operator*=(const T &a) {data*=a; return *this;};
|
inline Tensor& operator*=(const T &a) {data*=a; return *this;};
|
||||||
|
inline Tensor& operator/=(const T &a) {data/=a; return *this;};
|
||||||
|
|
||||||
|
void put(int fd) const;
|
||||||
|
void get(int fd);
|
||||||
|
|
||||||
//@@@TODO - unwinding to full size in a specified index
|
//@@@TODO - unwinding to full size in a specified index
|
||||||
//@@@TODO - contractions - basic and efficient
|
//@@@TODO - contractions - basic and efficient
|
||||||
//@@@TODO get/put to file, stream i/o
|
//@@@dodelat indexy
|
||||||
|
//@@@ dvojite rekurzivni loopover s callbackem - nebo iterator s funkci next???
|
||||||
|
//@@@ stream i/o na zaklade tohoto
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user