tensor: bugfix in unwind_index and linear_transform of single group

This commit is contained in:
2025-12-15 20:32:01 +01:00
parent 671b924c8c
commit 26ed939901
3 changed files with 118 additions and 2 deletions

View File

@@ -713,6 +713,9 @@ return q;
template<typename T>
Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
{
if(p.size()!=shape.size()) laerror("permutation size mismatch in permute_index_groups");
if(p.is_identity()) return *this;
//std::cout <<"permute_index_groups permutation = "<<p<<std::endl;
NRVec<INDEXGROUP> newshape=shape.permuted(p,true);
//std::cout <<"permute_index_groups newshape = "<<newshape<<std::endl;
@@ -836,6 +839,7 @@ if(group==0 && index==0 && shape[0].symmetry==0) //formal split of 1 index witho
}
if(shape[group].number==1) return unwind_index_group(group); //single index in the group
//general case - recalculate the shape and allocate the new tensor
NRVec<INDEXGROUP> newshape(shape.size()+1);
newshape[0].number=1;
@@ -856,7 +860,10 @@ for(int i=0; i<shape.size(); ++i)
--newshape[i+1].number;
flatindex += index;
}
else flatindex += shape[i].number;
else
{
if(i<group) flatindex += shape[i].number;
}
}
@@ -1840,6 +1847,10 @@ for(int g=shape.size()-1; g>=0; --g)
Tensor<T> mat(x[g],true); //flat tensor from a matrix
for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order
{
#ifdef LA_TENSOR_INDEXPOSITION
//what should we do with index position?
//either set upperindex of mat appropriately, or request ignoring that in contractions?
#endif
r= tmp.contraction(indexposition(rank()-1,tmp.shape),mat,INDEX(0,0),(T)1,false,false); //always the last index
if(i>0)
{
@@ -1870,6 +1881,59 @@ return r;
}
template<typename T>
Tensor<T> Tensor<T>::linear_transform(const int g, const Tensor<T> &x) const
{
if(g<0||g>=shape.size()) laerror("wrong index group number in linear_transform");
if(x.rank()!=2 || x.shape.size()!=2 || x.shape[0].number!=1||x.shape[1].number!=1) laerror("wrong tensor shape for linear_transform");
if(x.shape[0].range!=shape[g].range) laerror("index range mismatch in linear_transform");
Tensor<T> tmp(*this);
Tensor<T> r;
//contract all indices in the reverse order
int gnow=g;
for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order
{
r= tmp.contraction(INDEX(gnow,i),x,INDEX(0,0),(T)1,false,false);
++gnow; //new group number in r, one index was added left
if(i>0)
{
tmp=r;
r.deallocate();
}
}
//the group's indices are now individual ones leftmost in tensor r, restore group size and symmetry
if(shape[g].number>1)
{
INDEXLIST il(shape[g].number);
for(int i=0; i<shape[g].number; ++i)
{
il[i].group=i;
il[i].index=0;
}
r = r.merge_indices(il,shape[g].symmetry);
}
//permute the group (now 0) to its original position
if(g!=0)
{
NRPerm<int> p(r.shape.size());
for(int i=0; i<g; ++i) p[i+1] = (i+1)+1;
p[g+1]=1;
for(int i=g+1; i<r.shape.size();++i) p[i+1] = i+1;
//std::cout<<"perm "<<p;
r=r.permute_index_groups(p);
}
//preserve names
r.names=names;
return r;
}
template<typename T>
int Tensor<T>::findflatindex(const INDEXNAME nam) const