tensor class unwind_index
This commit is contained in:
parent
da0b3116f6
commit
5c6cb43c61
25
t.cc
25
t.cc
@ -3243,7 +3243,7 @@ cout <<epsilon.data;
|
|||||||
cout <<epsilon;
|
cout <<epsilon;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(1)
|
if(0)
|
||||||
{
|
{
|
||||||
int n=3;
|
int n=3;
|
||||||
NRVec<INDEXGROUP> s(4);
|
NRVec<INDEXGROUP> s(4);
|
||||||
@ -3270,5 +3270,28 @@ for(int i=0; i<n; ++i)
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(1)
|
||||||
|
{
|
||||||
|
int n=5;
|
||||||
|
INDEXGROUP g;
|
||||||
|
g.number=4;
|
||||||
|
g.symmetry= -1;
|
||||||
|
g.offset=0;
|
||||||
|
g.range=n;
|
||||||
|
|
||||||
|
Tensor<double> e(g);
|
||||||
|
e.randomize(1.);
|
||||||
|
Tensor<double> eu = e.unwind_index(0,1);
|
||||||
|
|
||||||
|
for(int i=0; i<n; ++i)
|
||||||
|
for(int j=0; j<n; ++j)
|
||||||
|
for(int k=0; k<n; ++k)
|
||||||
|
for(int l=0; l<n; ++l)
|
||||||
|
{
|
||||||
|
if(e(i,j,k,l)!=eu(j,i,k,l)) laerror("error in unwind_index");
|
||||||
|
}
|
||||||
|
cout <<e;
|
||||||
|
cout <<eu;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
90
tensor.cc
90
tensor.cc
@ -521,6 +521,8 @@ loopovergroups(*this,shape.size()-1,&pp,I,callback);
|
|||||||
const NRPerm<int> *help_p;
|
const NRPerm<int> *help_p;
|
||||||
template<typename T>
|
template<typename T>
|
||||||
Tensor<T> *help_t;
|
Tensor<T> *help_t;
|
||||||
|
template<typename T>
|
||||||
|
const Tensor<T> *help_tt;
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void permutecallback(const GROUPINDEX &I, T *v)
|
static void permutecallback(const GROUPINDEX &I, T *v)
|
||||||
@ -551,6 +553,94 @@ return r;
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
FLATINDEX superindex2flat(const SUPERINDEX &I)
|
||||||
|
{
|
||||||
|
int rank=0;
|
||||||
|
for(int i=0; i<I.size(); ++i) rank += I[i].size();
|
||||||
|
FLATINDEX J(rank);
|
||||||
|
int ii=0;
|
||||||
|
for(int i=0; i<I.size(); ++i)
|
||||||
|
{
|
||||||
|
for(int j=0; j<I[i].size(); ++j) J[ii++] = I[i][j];
|
||||||
|
}
|
||||||
|
return J;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void unwind_callback(const SUPERINDEX &I, T *v)
|
||||||
|
{
|
||||||
|
FLATINDEX J = superindex2flat(I);
|
||||||
|
FLATINDEX JP = J.permuted(*help_p,true);
|
||||||
|
//std::cout <<"TEST unwind_callback: from "<<JP<<" TO "<<J<<std::endl;
|
||||||
|
*v = (*help_tt<T>)(JP); //rhs operator() generates the redundant elements for the unwinded lhs tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
|
||||||
|
{
|
||||||
|
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
|
||||||
|
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
|
||||||
|
if(shape[group].number==1) //single index in the group
|
||||||
|
{
|
||||||
|
if(group==0) return *this; //is already the least significant group
|
||||||
|
NRPerm<int> p(shape.size());
|
||||||
|
p[1]= 1+group;
|
||||||
|
int ii=1;
|
||||||
|
for(int i=2; i<=shape.size(); ++i)
|
||||||
|
{
|
||||||
|
p[i]=ii++;
|
||||||
|
if(ii==1+group) ii++; //skip this
|
||||||
|
}
|
||||||
|
if(!p.is_valid()) laerror("internal error in unwind_index");
|
||||||
|
return permute_index_groups(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
//general case - recalculate the shape and allocate the new tensor
|
||||||
|
NRVec<indexgroup> newshape(shape.size()+1);
|
||||||
|
newshape[0].number=1;
|
||||||
|
newshape[0].symmetry=0;
|
||||||
|
newshape[0].range=shape[group].range;
|
||||||
|
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||||
|
newshape[0].offset = shape[group].offset;
|
||||||
|
#endif
|
||||||
|
int flatindex=0; //(group,index) in flat form
|
||||||
|
for(int i=0; i<shape.size(); ++i)
|
||||||
|
{
|
||||||
|
newshape[i+1] = shape[i];
|
||||||
|
if(i==group)
|
||||||
|
{
|
||||||
|
--newshape[i+1].number;
|
||||||
|
flatindex += index;
|
||||||
|
}
|
||||||
|
else flatindex += shape[i].number;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor<T> r(newshape);
|
||||||
|
if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
|
||||||
|
|
||||||
|
//compute the corresponding permutation of FLATINDEX for use in the callback
|
||||||
|
NRPerm<int> indexperm(rank());
|
||||||
|
indexperm[1]=flatindex+1;
|
||||||
|
int ii=1;
|
||||||
|
for(int i=2; i<=rank(); ++i)
|
||||||
|
{
|
||||||
|
indexperm[i] = ii++;
|
||||||
|
if(ii==flatindex+1) ii++; //skip this
|
||||||
|
}
|
||||||
|
if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index");
|
||||||
|
|
||||||
|
//loop recursively and do the unwinding
|
||||||
|
help_tt<T> = this;
|
||||||
|
help_p = &indexperm;
|
||||||
|
r.loopover(unwind_callback);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template class Tensor<double>;
|
template class Tensor<double>;
|
||||||
template class Tensor<std::complex<double> >;
|
template class Tensor<std::complex<double> >;
|
||||||
|
21
tensor.h
21
tensor.h
@ -62,9 +62,9 @@ typedef int LA_largeindex;
|
|||||||
typedef class indexgroup {
|
typedef class indexgroup {
|
||||||
public:
|
public:
|
||||||
int number; //number of indices
|
int number; //number of indices
|
||||||
int symmetry; //-1 0 or 1
|
int symmetry; //-1 0 or 1, later 2 for hermitian and -2 for antihermitian? - would need change in operator() and Signedpointer
|
||||||
#ifdef LA_TENSOR_ZERO_OFFSET
|
#ifdef LA_TENSOR_ZERO_OFFSET
|
||||||
static const LA_index offset = 0; //compiler can optimiza away some computations
|
static const LA_index offset = 0; //compiler can optimize away some computations
|
||||||
#else
|
#else
|
||||||
LA_index offset; //indices start at a general offset
|
LA_index offset; //indices start at a general offset
|
||||||
#endif
|
#endif
|
||||||
@ -99,6 +99,7 @@ typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
|||||||
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
||||||
typedef NRVec<LA_largeindex> GROUPINDEX; //set of indices in the symmetry groups
|
typedef NRVec<LA_largeindex> GROUPINDEX; //set of indices in the symmetry groups
|
||||||
|
|
||||||
|
FLATINDEX superindex2flat(const SUPERINDEX &I);
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
class Tensor {
|
class Tensor {
|
||||||
@ -132,11 +133,11 @@ public:
|
|||||||
LA_largeindex size() const {return data.size();};
|
LA_largeindex size() const {return data.size();};
|
||||||
void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();};
|
void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();};
|
||||||
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||||
inline T operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||||
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
inline Signedpointer<T> lhs(LA_index i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
inline Signedpointer<T> lhs(LA_index i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
||||||
inline T operator()(LA_index i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
inline T operator()(LA_index i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||||
|
|
||||||
inline Tensor& operator=(const Tensor &rhs) {myrank=rhs.myrank; shape=rhs.shape; groupsizes=rhs.groupsizes; cumsizes=rhs.cumsizes; data=rhs.data; return *this;};
|
inline Tensor& operator=(const Tensor &rhs) {myrank=rhs.myrank; shape=rhs.shape; groupsizes=rhs.groupsizes; cumsizes=rhs.cumsizes; data=rhs.data; return *this;};
|
||||||
|
|
||||||
@ -177,13 +178,11 @@ public:
|
|||||||
void grouploopover(void (*callback)(const GROUPINDEX &, T *)); //loop over all elements disregarding the internal structure of index groups
|
void grouploopover(void (*callback)(const GROUPINDEX &, T *)); //loop over all elements disregarding the internal structure of index groups
|
||||||
|
|
||||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||||
|
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
||||||
|
|
||||||
//@@@TODO - unwinding to full size in a specified index
|
//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
|
||||||
//@@@contraction by a whole index group or by individual single index
|
//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne
|
||||||
//@@@ general antisymmetrization operator Kucharski style
|
//@@@contraction
|
||||||
//@@@TODO - contractions - basic and efficient? first contraction in a single index; between a given group+index in group at each tensor
|
|
||||||
//@@@outer product and product with a contraction
|
|
||||||
//@@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric grtoup - obecne symmetry change krome +1 na -1 vse mozne
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user