tensor class unwind_index
This commit is contained in:
90
tensor.cc
90
tensor.cc
@@ -521,6 +521,8 @@ loopovergroups(*this,shape.size()-1,&pp,I,callback);
|
||||
const NRPerm<int> *help_p;
|
||||
template<typename T>
|
||||
Tensor<T> *help_t;
|
||||
template<typename T>
|
||||
const Tensor<T> *help_tt;
|
||||
|
||||
template<typename T>
|
||||
static void permutecallback(const GROUPINDEX &I, T *v)
|
||||
@@ -551,6 +553,94 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
FLATINDEX superindex2flat(const SUPERINDEX &I)
|
||||
{
|
||||
int rank=0;
|
||||
for(int i=0; i<I.size(); ++i) rank += I[i].size();
|
||||
FLATINDEX J(rank);
|
||||
int ii=0;
|
||||
for(int i=0; i<I.size(); ++i)
|
||||
{
|
||||
for(int j=0; j<I[i].size(); ++j) J[ii++] = I[i][j];
|
||||
}
|
||||
return J;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void unwind_callback(const SUPERINDEX &I, T *v)
|
||||
{
|
||||
FLATINDEX J = superindex2flat(I);
|
||||
FLATINDEX JP = J.permuted(*help_p,true);
|
||||
//std::cout <<"TEST unwind_callback: from "<<JP<<" TO "<<J<<std::endl;
|
||||
*v = (*help_tt<T>)(JP); //rhs operator() generates the redundant elements for the unwinded lhs tensor
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
|
||||
{
|
||||
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
|
||||
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
|
||||
if(shape[group].number==1) //single index in the group
|
||||
{
|
||||
if(group==0) return *this; //is already the least significant group
|
||||
NRPerm<int> p(shape.size());
|
||||
p[1]= 1+group;
|
||||
int ii=1;
|
||||
for(int i=2; i<=shape.size(); ++i)
|
||||
{
|
||||
p[i]=ii++;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
}
|
||||
if(!p.is_valid()) laerror("internal error in unwind_index");
|
||||
return permute_index_groups(p);
|
||||
}
|
||||
|
||||
//general case - recalculate the shape and allocate the new tensor
|
||||
NRVec<indexgroup> newshape(shape.size()+1);
|
||||
newshape[0].number=1;
|
||||
newshape[0].symmetry=0;
|
||||
newshape[0].range=shape[group].range;
|
||||
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||
newshape[0].offset = shape[group].offset;
|
||||
#endif
|
||||
int flatindex=0; //(group,index) in flat form
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
newshape[i+1] = shape[i];
|
||||
if(i==group)
|
||||
{
|
||||
--newshape[i+1].number;
|
||||
flatindex += index;
|
||||
}
|
||||
else flatindex += shape[i].number;
|
||||
}
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
|
||||
|
||||
//compute the corresponding permutation of FLATINDEX for use in the callback
|
||||
NRPerm<int> indexperm(rank());
|
||||
indexperm[1]=flatindex+1;
|
||||
int ii=1;
|
||||
for(int i=2; i<=rank(); ++i)
|
||||
{
|
||||
indexperm[i] = ii++;
|
||||
if(ii==flatindex+1) ii++; //skip this
|
||||
}
|
||||
if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index");
|
||||
|
||||
//loop recursively and do the unwinding
|
||||
help_tt<T> = this;
|
||||
help_p = &indexperm;
|
||||
r.loopover(unwind_callback);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
|
||||
Reference in New Issue
Block a user