working on tensor class
This commit is contained in:
parent
c67549a296
commit
80f915946b
17
t.cc
17
t.cc
@ -3183,7 +3183,7 @@ cin>>d>>n;
|
|||||||
cout <<simplicial(d,n)<<" "<<binom(n+d-1,d)<<endl;
|
cout <<simplicial(d,n)<<" "<<binom(n+d-1,d)<<endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(1)
|
if(0)
|
||||||
{
|
{
|
||||||
NRVec<int> d({6,2,1,4,3,5});
|
NRVec<int> d({6,2,1,4,3,5});
|
||||||
d.copyonwrite();
|
d.copyonwrite();
|
||||||
@ -3191,4 +3191,19 @@ netsort(d.size(),&d[0]);
|
|||||||
cout <<d;
|
cout <<d;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(1)
|
||||||
|
{
|
||||||
|
INDEXGROUP g;
|
||||||
|
g.number=3;
|
||||||
|
g.symmetry= -1;
|
||||||
|
g.offset=1;
|
||||||
|
g.range=3;
|
||||||
|
|
||||||
|
Tensor<double> epsilon(g);
|
||||||
|
cout <<epsilon.size()<<endl;
|
||||||
|
//cout <<epsilon(3,2,1)<<endl;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
108
tensor.cc
108
tensor.cc
@ -21,12 +21,98 @@
|
|||||||
#include "laerror.h"
|
#include "laerror.h"
|
||||||
#include "qsort.h"
|
#include "qsort.h"
|
||||||
#include "miscfunc.h"
|
#include "miscfunc.h"
|
||||||
|
#include <complex>
|
||||||
|
|
||||||
|
|
||||||
namespace LA {
|
namespace LA {
|
||||||
|
|
||||||
|
LA_largeindex subindex(int *sign, const INDEXGROUP &g, const NRVec<LA_index> &I) //index of one subgroup
|
||||||
|
{
|
||||||
|
#ifdef DEBUG
|
||||||
|
if(I.size()<=0) laerror("empty index group in subindex");
|
||||||
|
if(g.number!=I.size()) laerror("mismatch in the number of indices in a group");
|
||||||
|
for(int i=0; i<I.size(); ++i) if(I[i]<g.offset || I[i] >= g.offset+g.range) laerror("index out of range in tensor subindex");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
switch(I.size()) //a few special cases for efficiency
|
||||||
|
{
|
||||||
|
case 0:
|
||||||
|
*sign=0;
|
||||||
|
return 0;
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
*sign=1;
|
||||||
|
return I[0]-g.offset;
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
{
|
||||||
|
*sign=1;
|
||||||
|
if(g.symmetry==0) return (I[1]-g.offset)*g.range+I[0]-g.offset;
|
||||||
|
LA_index i0,i1;
|
||||||
|
if(I[0]>I[1]) {i1=I[0]; i0=I[1]; if(g.symmetry<0) *sign = -1;} else {i1=I[1]; i0=I[0];}
|
||||||
|
i0 -= g.offset;
|
||||||
|
i1 -= g.offset;
|
||||||
|
if(g.symmetry<0)
|
||||||
|
{
|
||||||
|
if(i0==i1) {*sign=0; return -1;}
|
||||||
|
return i1*(i1-1)/2+i0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return i1*(i1+1)/2+i0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
default: //general case
|
||||||
|
{
|
||||||
|
*sign=1;
|
||||||
|
|
||||||
|
if(g.symmetry==0) //rectangular case
|
||||||
|
{
|
||||||
|
LA_largeindex r=0;
|
||||||
|
for(int i=I.size()-1; i>=0; --i)
|
||||||
|
{
|
||||||
|
r*= g.range;
|
||||||
|
r+= I[i]-g.offset;
|
||||||
|
}
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//compressed storage case
|
||||||
|
NRVec<LA_index> II(I);
|
||||||
|
II.copyonwrite();
|
||||||
|
II -= g.offset;
|
||||||
|
int parity=netsort(II.size(),&II[0]);
|
||||||
|
if(g.symmetry<0 && (parity&1)) *sign= -1;
|
||||||
|
if(g.symmetry<0) //antisymmetric
|
||||||
|
{
|
||||||
|
for(int i=0; i<I.size()-1; ++i)
|
||||||
|
if(II[i]==II[i+1])
|
||||||
|
{*sign=0; return -1;} //identical indices of antisymmetric tensor
|
||||||
|
|
||||||
|
LA_largeindex r=0;
|
||||||
|
for(int i=0; i<II.size(); ++i) r += simplicial(i+1,II[i]-i);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
else //symmetric
|
||||||
|
{
|
||||||
|
LA_largeindex r=0;
|
||||||
|
for(int i=0; i<II.size(); ++i) r += simplicial(i+1,II[i]);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
laerror("this error should not happen");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
LA_largeindex Tensor<T>::index(int *sign, const SUPERINDEX &I)
|
LA_largeindex Tensor<T>::index(int *sign, const SUPERINDEX &I) const
|
||||||
{
|
{
|
||||||
//check index structure and ranges
|
//check index structure and ranges
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
@ -45,11 +131,27 @@ for(int i=0; i<I.size(); ++i)
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
LA_largeindex r=0;
|
||||||
//@@@@@@@@@
|
*sign=1;
|
||||||
|
for(int g=0; g<shape.size(); ++g) //loop over index groups
|
||||||
|
{
|
||||||
|
int gsign;
|
||||||
|
LA_largeindex groupindex = subindex(&gsign,shape[g],I[g]);
|
||||||
|
std::cout <<"INDEX TEST group "<<g<<" cumsizes "<< cumsizes[g]<<" groupindex "<<groupindex<<std::endl;
|
||||||
|
*sign *= gsign;
|
||||||
|
if(groupindex == -1) return -1;
|
||||||
|
r += groupindex * cumsizes[g];
|
||||||
|
}
|
||||||
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//@@@@todo flatindex
|
||||||
|
|
||||||
|
//@@@@todo vindex
|
||||||
|
|
||||||
|
|
||||||
|
template class Tensor<double>;
|
||||||
|
template class Tensor<std::complex<double> >;
|
||||||
|
|
||||||
|
|
||||||
}//namespace
|
}//namespace
|
||||||
|
40
tensor.h
40
tensor.h
@ -42,7 +42,8 @@ class Signedpointer
|
|||||||
T *ptr;
|
T *ptr;
|
||||||
int sgn;
|
int sgn;
|
||||||
public:
|
public:
|
||||||
Signedpointer(const T *p, int s) : ptr(p),sgn(s) {};
|
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
|
||||||
|
T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;}
|
||||||
//@@@@@@operations on singedpointer as LHS of the non-const tensor.operator() expressions
|
//@@@@@@operations on singedpointer as LHS of the non-const tensor.operator() expressions
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -51,12 +52,22 @@ typedef int LA_index;
|
|||||||
typedef int LA_largeindex;
|
typedef int LA_largeindex;
|
||||||
|
|
||||||
typedef class indexgroup {
|
typedef class indexgroup {
|
||||||
|
public:
|
||||||
int number; //number of indices
|
int number; //number of indices
|
||||||
int symmetry; //-1 0 or 1
|
int symmetry; //-1 0 or 1
|
||||||
LA_index offset; //indices start at
|
LA_index offset; //indices start at
|
||||||
LA_index range; //indices span this range
|
LA_index range; //indices span this range
|
||||||
} INDEXGROUP;
|
} INDEXGROUP;
|
||||||
|
|
||||||
|
template<>
|
||||||
|
class LA_traits<indexgroup> {
|
||||||
|
public:
|
||||||
|
static bool is_plaindata() {return true;};
|
||||||
|
static void copyonwrite(indexgroup& x) {};
|
||||||
|
typedef INDEXGROUP normtype;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
||||||
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
||||||
|
|
||||||
@ -71,22 +82,29 @@ class Tensor {
|
|||||||
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups
|
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LA_largeindex index(int *sign, const SUPERINDEX &I); //map the tensor indices to the position in data
|
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
|
||||||
LA_largeindex index(int *sign, const FLATINDEX &I); //map the tensor indices to the position in data
|
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
|
||||||
LA_largeindex vindex(int *sign, int i1, va_list args); //map list of indices to the position in data @@@must call va_end
|
LA_largeindex vindex(int *sign, int i1, va_list args) const; //map list of indices to the position in data @@@must call va_end
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
//constructors
|
||||||
Tensor() {};
|
Tensor() {};
|
||||||
Tensor(const NRVec<indexgroup> &s) : shape(s), data((int)getsize()) {data.clear();};
|
Tensor(const NRVec<indexgroup> &s) : shape(s), data((int)getsize()) {data.clear();}; //general tensor
|
||||||
|
Tensor(const indexgroup &g) {shape.resize(1); shape[0]=g; data.resize(getsize()); data.clear();}; //tensor with a single index group
|
||||||
|
|
||||||
int getrank() const; //is computed from shape
|
int getrank() const; //is computed from shape
|
||||||
LA_largeindex getsize(); //set redundant data and return total size
|
LA_largeindex getsize(); //set redundant data and return total size
|
||||||
|
LA_largeindex size() const {return data.size();};
|
||||||
void copyonwrite() {shape.copyonwrite(); data.copyonwrite();};
|
void copyonwrite() {shape.copyonwrite(); data.copyonwrite();};
|
||||||
inline Signedpointer<T> operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
inline Signedpointer<T> operator[](const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||||
inline const T& operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
inline Signedpointer<T> operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
inline Signedpointer<T> operator[](const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||||
inline const T& operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
inline Signedpointer<T> operator()(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
//inline Signedpointer<T> operator[](int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
||||||
inline const T& operator()(int i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
//cannot have operator[] with variable number of argmuments
|
||||||
|
inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
|
//@@@do a 'set' operatiaon with va_arg instead
|
||||||
|
|
||||||
//NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it
|
//NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it
|
||||||
|
|
||||||
//@@@TODO - unwinding to full size in a specified index
|
//@@@TODO - unwinding to full size in a specified index
|
||||||
|
Loading…
Reference in New Issue
Block a user