working on tensor class

This commit is contained in:
2024-04-03 18:43:55 +02:00
parent c67549a296
commit 80f915946b
3 changed files with 150 additions and 15 deletions

View File

@@ -42,7 +42,8 @@ class Signedpointer
T *ptr;
int sgn;
public:
Signedpointer(const T *p, int s) : ptr(p),sgn(s) {};
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
//@@@@@@operations on singedpointer as LHS of the non-const tensor.operator() expressions
};
@@ -51,12 +52,22 @@ typedef int LA_index;
typedef int LA_largeindex;
typedef class indexgroup {
public:
int number; //number of indices
int symmetry; //-1 0 or 1
LA_index offset; //indices start at
LA_index range; //indices span this range
} INDEXGROUP;
template<>
class LA_traits<indexgroup> {
public:
static bool is_plaindata() {return true;};
static void copyonwrite(indexgroup& x) {};
typedef INDEXGROUP normtype;
};
typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
@@ -71,22 +82,29 @@ class Tensor {
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups
private:
LA_largeindex index(int *sign, const SUPERINDEX &I); //map the tensor indices to the position in data
LA_largeindex index(int *sign, const FLATINDEX &I); //map the tensor indices to the position in data
LA_largeindex vindex(int *sign, int i1, va_list args); //map list of indices to the position in data @@@must call va_end
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
LA_largeindex vindex(int *sign, int i1, va_list args) const; //map list of indices to the position in data @@@must call va_end
public:
//constructors
Tensor() {};
Tensor(const NRVec<indexgroup> &s) : shape(s), data((int)getsize()) {data.clear();};
Tensor(const NRVec<indexgroup> &s) : shape(s), data((int)getsize()) {data.clear();}; //general tensor
Tensor(const indexgroup &g) {shape.resize(1); shape[0]=g; data.resize(getsize()); data.clear();}; //tensor with a single index group
int getrank() const; //is computed from shape
LA_largeindex getsize(); //set redundant data and return total size
LA_largeindex size() const {return data.size();};
void copyonwrite() {shape.copyonwrite(); data.copyonwrite();};
inline Signedpointer<T> operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline const T& operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline Signedpointer<T> operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline const T& operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline Signedpointer<T> operator()(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
inline const T& operator()(int i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline Signedpointer<T> operator[](const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline T operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline Signedpointer<T> operator[](const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
//inline Signedpointer<T> operator[](int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
//cannot have operator[] with variable number of argmuments
inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
//@@@do a 'set' operatiaon with va_arg instead
//NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it
//@@@TODO - unwinding to full size in a specified index