diff --git a/t.cc b/t.cc index 325b207..a93fa46 100644 --- a/t.cc +++ b/t.cc @@ -3183,7 +3183,7 @@ cin>>d>>n; cout < d({6,2,1,4,3,5}); d.copyonwrite(); @@ -3191,4 +3191,19 @@ netsort(d.size(),&d[0]); cout < epsilon(g); +cout < namespace LA { +LA_largeindex subindex(int *sign, const INDEXGROUP &g, const NRVec &I) //index of one subgroup +{ +#ifdef DEBUG +if(I.size()<=0) laerror("empty index group in subindex"); +if(g.number!=I.size()) laerror("mismatch in the number of indices in a group"); +for(int i=0; i= g.offset+g.range) laerror("index out of range in tensor subindex"); +#endif + +switch(I.size()) //a few special cases for efficiency + { + case 0: + *sign=0; + return 0; + break; + case 1: + *sign=1; + return I[0]-g.offset; + break; + case 2: + { + *sign=1; + if(g.symmetry==0) return (I[1]-g.offset)*g.range+I[0]-g.offset; + LA_index i0,i1; + if(I[0]>I[1]) {i1=I[0]; i0=I[1]; if(g.symmetry<0) *sign = -1;} else {i1=I[1]; i0=I[0];} + i0 -= g.offset; + i1 -= g.offset; + if(g.symmetry<0) + { + if(i0==i1) {*sign=0; return -1;} + return i1*(i1-1)/2+i0; + } + else + { + return i1*(i1+1)/2+i0; + } + } + break; + + default: //general case + { + *sign=1; + + if(g.symmetry==0) //rectangular case + { + LA_largeindex r=0; + for(int i=I.size()-1; i>=0; --i) + { + r*= g.range; + r+= I[i]-g.offset; + } + return r; + } + } + + //compressed storage case + NRVec II(I); + II.copyonwrite(); + II -= g.offset; + int parity=netsort(II.size(),&II[0]); + if(g.symmetry<0 && (parity&1)) *sign= -1; + if(g.symmetry<0) //antisymmetric + { + for(int i=0; i -LA_largeindex Tensor::index(int *sign, const SUPERINDEX &I) +LA_largeindex Tensor::index(int *sign, const SUPERINDEX &I) const { //check index structure and ranges #ifdef DEBUG @@ -45,11 +131,27 @@ for(int i=0; i; +template class Tensor >; }//namespace diff --git a/tensor.h b/tensor.h index fbbb84e..8a0fcbe 100644 --- a/tensor.h +++ b/tensor.h @@ -42,7 +42,8 @@ class Signedpointer T *ptr; int sgn; public: - Signedpointer(const T *p, int s) : ptr(p),sgn(s) {}; + Signedpointer(T *p, int s) : ptr(p),sgn(s) {}; + T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;} //@@@@@@operations on singedpointer as LHS of the non-const tensor.operator() expressions }; @@ -51,12 +52,22 @@ typedef int LA_index; typedef int LA_largeindex; typedef class indexgroup { +public: int number; //number of indices int symmetry; //-1 0 or 1 LA_index offset; //indices start at LA_index range; //indices span this range } INDEXGROUP; +template<> +class LA_traits { + public: + static bool is_plaindata() {return true;}; + static void copyonwrite(indexgroup& x) {}; + typedef INDEXGROUP normtype; +}; + + typedef NRVec FLATINDEX; //all indices but in a single vector typedef NRVec > SUPERINDEX; //all indices in the INDEXGROUP structure @@ -71,22 +82,29 @@ class Tensor { NRVec cumsizes; //cumulative sizes of symmetry index groups private: - LA_largeindex index(int *sign, const SUPERINDEX &I); //map the tensor indices to the position in data - LA_largeindex index(int *sign, const FLATINDEX &I); //map the tensor indices to the position in data - LA_largeindex vindex(int *sign, int i1, va_list args); //map list of indices to the position in data @@@must call va_end + LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data + LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data + LA_largeindex vindex(int *sign, int i1, va_list args) const; //map list of indices to the position in data @@@must call va_end public: + //constructors Tensor() {}; - Tensor(const NRVec &s) : shape(s), data((int)getsize()) {data.clear();}; + Tensor(const NRVec &s) : shape(s), data((int)getsize()) {data.clear();}; //general tensor + Tensor(const indexgroup &g) {shape.resize(1); shape[0]=g; data.resize(getsize()); data.clear();}; //tensor with a single index group + int getrank() const; //is computed from shape LA_largeindex getsize(); //set redundant data and return total size + LA_largeindex size() const {return data.size();}; void copyonwrite() {shape.copyonwrite(); data.copyonwrite();}; - inline Signedpointer operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer(&data[i],sign);}; - inline const T& operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; - inline Signedpointer operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer(&data[i],sign);}; - inline const T& operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; - inline Signedpointer operator()(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer(&data[i],sign); }; - inline const T& operator()(int i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; + inline Signedpointer operator[](const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer(&data[i],sign);}; + inline T operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; + inline Signedpointer operator[](const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer(&data[i],sign);}; + inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; + //inline Signedpointer operator[](int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer(&data[i],sign); }; + //cannot have operator[] with variable number of argmuments + inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; + //@@@do a 'set' operatiaon with va_arg instead + //NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it //@@@TODO - unwinding to full size in a specified index