From 26ed939901fab6d2f28b4feeb8f865345713ed19 Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Mon, 15 Dec 2025 20:32:01 +0100 Subject: [PATCH] tensor: bugfix in unwind_index and linear_transform of single group --- t.cc | 52 ++++++++++++++++++++++++++++++++++++++++++- tensor.cc | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- tensor.h | 2 ++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/t.cc b/t.cc index afbfa31..b27c65a 100644 --- a/t.cc +++ b/t.cc @@ -4534,7 +4534,7 @@ for(int i=0; i>r>>n>>sym; +NRVec shape(3); + shape[0].number=2; + shape[0].symmetry=0; + shape[0].range=n+1; + shape[0].offset=0; + + shape[1].number=r; + shape[1].symmetry= sym; + shape[1].range=n; + shape[1].offset=0; + + shape[2].number=2; + shape[2].symmetry=0; + shape[2].range=n+2; + shape[2].offset=0; + + +Tensor x(shape); x.randomize(1.); +x.defaultnames(); +cout <<"x= "< yshape(2); + yshape[0].number=1; + yshape[0].symmetry=0; + yshape[0].range=n; + yshape[0].offset=0; + + yshape[1].number=1; + yshape[1].symmetry= 0; + yshape[1].range=n+3; + yshape[1].offset=0; + +Tensor y(yshape); y.randomize(1.); + +INDEX posit(1,0); +//Tensor z=x.unwind_index(1,1); +//Tensor z=x.contraction(INDEX(1,1),y,INDEX(0,0),1,false,false); +Tensor z=x.linear_transform(1,y); + +cout < Tensor Tensor::permute_index_groups(const NRPerm &p) const { +if(p.size()!=shape.size()) laerror("permutation size mismatch in permute_index_groups"); +if(p.is_identity()) return *this; + //std::cout <<"permute_index_groups permutation = "< newshape=shape.permuted(p,true); //std::cout <<"permute_index_groups newshape = "< newshape(shape.size()+1); newshape[0].number=1; @@ -856,7 +860,10 @@ for(int i=0; i=0; --g) Tensor mat(x[g],true); //flat tensor from a matrix for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order { +#ifdef LA_TENSOR_INDEXPOSITION + //what should we do with index position? + //either set upperindex of mat appropriately, or request ignoring that in contractions? +#endif r= tmp.contraction(indexposition(rank()-1,tmp.shape),mat,INDEX(0,0),(T)1,false,false); //always the last index if(i>0) { @@ -1870,6 +1881,59 @@ return r; } +template +Tensor Tensor::linear_transform(const int g, const Tensor &x) const +{ +if(g<0||g>=shape.size()) laerror("wrong index group number in linear_transform"); +if(x.rank()!=2 || x.shape.size()!=2 || x.shape[0].number!=1||x.shape[1].number!=1) laerror("wrong tensor shape for linear_transform"); +if(x.shape[0].range!=shape[g].range) laerror("index range mismatch in linear_transform"); + +Tensor tmp(*this); +Tensor r; + + +//contract all indices in the reverse order + int gnow=g; + for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order + { + r= tmp.contraction(INDEX(gnow,i),x,INDEX(0,0),(T)1,false,false); + ++gnow; //new group number in r, one index was added left + if(i>0) + { + tmp=r; + r.deallocate(); + } + } +//the group's indices are now individual ones leftmost in tensor r, restore group size and symmetry + if(shape[g].number>1) + { + INDEXLIST il(shape[g].number); + for(int i=0; i p(r.shape.size()); + for(int i=0; i int Tensor::findflatindex(const INDEXNAME nam) const diff --git a/tensor.h b/tensor.h index 7b6d379..c4bbc35 100644 --- a/tensor.h +++ b/tensor.h @@ -421,6 +421,8 @@ public: NRVec > Tucker(typename LA_traits::normtype thr=1e-12, bool inverseorder=false); //HOSVD-Tucker decomposition, return core tensor in *this, flattened Tensor inverseTucker(const NRVec > &x, bool inverseorder=false) const; //rebuild the original tensor from Tucker Tensor linear_transform(const NRVec > &x) const; //linear transform by a different matrix per each index group, preserving group symmetries + Tensor linear_transform(const int g, const Tensor &x) const; //linear transform a single group, preserve its position and symmetry, x must be rank-2 flat tensor (this is useful for raising/lowering indices) + Tensor linear_transform(const int g, const NRMat &x) const {Tensor mat(x,true); return linear_transform(g,mat);}; //linear transform a single group, preserve its position and symmetry };