tensor class unwind_index

This commit is contained in:
2024-04-25 16:38:35 +02:00
parent da0b3116f6
commit 5c6cb43c61
3 changed files with 124 additions and 12 deletions

View File

@@ -521,6 +521,8 @@ loopovergroups(*this,shape.size()-1,&pp,I,callback);
const NRPerm<int> *help_p;
template<typename T>
Tensor<T> *help_t;
template<typename T>
const Tensor<T> *help_tt;
template<typename T>
static void permutecallback(const GROUPINDEX &I, T *v)
@@ -551,6 +553,94 @@ return r;
}
FLATINDEX superindex2flat(const SUPERINDEX &I)
{
int rank=0;
for(int i=0; i<I.size(); ++i) rank += I[i].size();
FLATINDEX J(rank);
int ii=0;
for(int i=0; i<I.size(); ++i)
{
for(int j=0; j<I[i].size(); ++j) J[ii++] = I[i][j];
}
return J;
}
template<typename T>
static void unwind_callback(const SUPERINDEX &I, T *v)
{
FLATINDEX J = superindex2flat(I);
FLATINDEX JP = J.permuted(*help_p,true);
//std::cout <<"TEST unwind_callback: from "<<JP<<" TO "<<J<<std::endl;
*v = (*help_tt<T>)(JP); //rhs operator() generates the redundant elements for the unwinded lhs tensor
}
template<typename T>
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
{
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
if(shape[group].number==1) //single index in the group
{
if(group==0) return *this; //is already the least significant group
NRPerm<int> p(shape.size());
p[1]= 1+group;
int ii=1;
for(int i=2; i<=shape.size(); ++i)
{
p[i]=ii++;
if(ii==1+group) ii++; //skip this
}
if(!p.is_valid()) laerror("internal error in unwind_index");
return permute_index_groups(p);
}
//general case - recalculate the shape and allocate the new tensor
NRVec<indexgroup> newshape(shape.size()+1);
newshape[0].number=1;
newshape[0].symmetry=0;
newshape[0].range=shape[group].range;
#ifndef LA_TENSOR_ZERO_OFFSET
newshape[0].offset = shape[group].offset;
#endif
int flatindex=0; //(group,index) in flat form
for(int i=0; i<shape.size(); ++i)
{
newshape[i+1] = shape[i];
if(i==group)
{
--newshape[i+1].number;
flatindex += index;
}
else flatindex += shape[i].number;
}
Tensor<T> r(newshape);
if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
//compute the corresponding permutation of FLATINDEX for use in the callback
NRPerm<int> indexperm(rank());
indexperm[1]=flatindex+1;
int ii=1;
for(int i=2; i<=rank(); ++i)
{
indexperm[i] = ii++;
if(ii==flatindex+1) ii++; //skip this
}
if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index");
//loop recursively and do the unwinding
help_tt<T> = this;
help_p = &indexperm;
r.loopover(unwind_callback);
return r;
}
template class Tensor<double>;
template class Tensor<std::complex<double> >;