tensor class unwind_index
This commit is contained in:
parent
da0b3116f6
commit
5c6cb43c61
25
t.cc
25
t.cc
@ -3243,7 +3243,7 @@ cout <<epsilon.data;
|
||||
cout <<epsilon;
|
||||
}
|
||||
|
||||
if(1)
|
||||
if(0)
|
||||
{
|
||||
int n=3;
|
||||
NRVec<INDEXGROUP> s(4);
|
||||
@ -3270,5 +3270,28 @@ for(int i=0; i<n; ++i)
|
||||
}
|
||||
|
||||
|
||||
if(1)
|
||||
{
|
||||
int n=5;
|
||||
INDEXGROUP g;
|
||||
g.number=4;
|
||||
g.symmetry= -1;
|
||||
g.offset=0;
|
||||
g.range=n;
|
||||
|
||||
Tensor<double> e(g);
|
||||
e.randomize(1.);
|
||||
Tensor<double> eu = e.unwind_index(0,1);
|
||||
|
||||
for(int i=0; i<n; ++i)
|
||||
for(int j=0; j<n; ++j)
|
||||
for(int k=0; k<n; ++k)
|
||||
for(int l=0; l<n; ++l)
|
||||
{
|
||||
if(e(i,j,k,l)!=eu(j,i,k,l)) laerror("error in unwind_index");
|
||||
}
|
||||
cout <<e;
|
||||
cout <<eu;
|
||||
}
|
||||
|
||||
}
|
||||
|
90
tensor.cc
90
tensor.cc
@ -521,6 +521,8 @@ loopovergroups(*this,shape.size()-1,&pp,I,callback);
|
||||
const NRPerm<int> *help_p;
|
||||
template<typename T>
|
||||
Tensor<T> *help_t;
|
||||
template<typename T>
|
||||
const Tensor<T> *help_tt;
|
||||
|
||||
template<typename T>
|
||||
static void permutecallback(const GROUPINDEX &I, T *v)
|
||||
@ -551,6 +553,94 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
FLATINDEX superindex2flat(const SUPERINDEX &I)
|
||||
{
|
||||
int rank=0;
|
||||
for(int i=0; i<I.size(); ++i) rank += I[i].size();
|
||||
FLATINDEX J(rank);
|
||||
int ii=0;
|
||||
for(int i=0; i<I.size(); ++i)
|
||||
{
|
||||
for(int j=0; j<I[i].size(); ++j) J[ii++] = I[i][j];
|
||||
}
|
||||
return J;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void unwind_callback(const SUPERINDEX &I, T *v)
|
||||
{
|
||||
FLATINDEX J = superindex2flat(I);
|
||||
FLATINDEX JP = J.permuted(*help_p,true);
|
||||
//std::cout <<"TEST unwind_callback: from "<<JP<<" TO "<<J<<std::endl;
|
||||
*v = (*help_tt<T>)(JP); //rhs operator() generates the redundant elements for the unwinded lhs tensor
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
|
||||
{
|
||||
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
|
||||
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
|
||||
if(shape[group].number==1) //single index in the group
|
||||
{
|
||||
if(group==0) return *this; //is already the least significant group
|
||||
NRPerm<int> p(shape.size());
|
||||
p[1]= 1+group;
|
||||
int ii=1;
|
||||
for(int i=2; i<=shape.size(); ++i)
|
||||
{
|
||||
p[i]=ii++;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
}
|
||||
if(!p.is_valid()) laerror("internal error in unwind_index");
|
||||
return permute_index_groups(p);
|
||||
}
|
||||
|
||||
//general case - recalculate the shape and allocate the new tensor
|
||||
NRVec<indexgroup> newshape(shape.size()+1);
|
||||
newshape[0].number=1;
|
||||
newshape[0].symmetry=0;
|
||||
newshape[0].range=shape[group].range;
|
||||
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||
newshape[0].offset = shape[group].offset;
|
||||
#endif
|
||||
int flatindex=0; //(group,index) in flat form
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
newshape[i+1] = shape[i];
|
||||
if(i==group)
|
||||
{
|
||||
--newshape[i+1].number;
|
||||
flatindex += index;
|
||||
}
|
||||
else flatindex += shape[i].number;
|
||||
}
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
|
||||
|
||||
//compute the corresponding permutation of FLATINDEX for use in the callback
|
||||
NRPerm<int> indexperm(rank());
|
||||
indexperm[1]=flatindex+1;
|
||||
int ii=1;
|
||||
for(int i=2; i<=rank(); ++i)
|
||||
{
|
||||
indexperm[i] = ii++;
|
||||
if(ii==flatindex+1) ii++; //skip this
|
||||
}
|
||||
if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index");
|
||||
|
||||
//loop recursively and do the unwinding
|
||||
help_tt<T> = this;
|
||||
help_p = &indexperm;
|
||||
r.loopover(unwind_callback);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
|
21
tensor.h
21
tensor.h
@ -62,9 +62,9 @@ typedef int LA_largeindex;
|
||||
typedef class indexgroup {
|
||||
public:
|
||||
int number; //number of indices
|
||||
int symmetry; //-1 0 or 1
|
||||
int symmetry; //-1 0 or 1, later 2 for hermitian and -2 for antihermitian? - would need change in operator() and Signedpointer
|
||||
#ifdef LA_TENSOR_ZERO_OFFSET
|
||||
static const LA_index offset = 0; //compiler can optimiza away some computations
|
||||
static const LA_index offset = 0; //compiler can optimize away some computations
|
||||
#else
|
||||
LA_index offset; //indices start at a general offset
|
||||
#endif
|
||||
@ -99,6 +99,7 @@ typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
||||
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
||||
typedef NRVec<LA_largeindex> GROUPINDEX; //set of indices in the symmetry groups
|
||||
|
||||
FLATINDEX superindex2flat(const SUPERINDEX &I);
|
||||
|
||||
template<typename T>
|
||||
class Tensor {
|
||||
@ -132,11 +133,11 @@ public:
|
||||
LA_largeindex size() const {return data.size();};
|
||||
void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();};
|
||||
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||
inline T operator()(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||
inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline T operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline Signedpointer<T> lhs(LA_index i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
|
||||
inline T operator()(LA_index i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline T operator()(LA_index i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
|
||||
inline Tensor& operator=(const Tensor &rhs) {myrank=rhs.myrank; shape=rhs.shape; groupsizes=rhs.groupsizes; cumsizes=rhs.cumsizes; data=rhs.data; return *this;};
|
||||
|
||||
@ -177,13 +178,11 @@ public:
|
||||
void grouploopover(void (*callback)(const GROUPINDEX &, T *)); //loop over all elements disregarding the internal structure of index groups
|
||||
|
||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
||||
|
||||
//@@@TODO - unwinding to full size in a specified index
|
||||
//@@@contraction by a whole index group or by individual single index
|
||||
//@@@ general antisymmetrization operator Kucharski style
|
||||
//@@@TODO - contractions - basic and efficient? first contraction in a single index; between a given group+index in group at each tensor
|
||||
//@@@outer product and product with a contraction
|
||||
//@@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric grtoup - obecne symmetry change krome +1 na -1 vse mozne
|
||||
//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
|
||||
//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne
|
||||
//@@@contraction
|
||||
};
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user