From 0b91e88dca913301ee0ea7ebd5f4271033c4aff6 Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Thu, 16 May 2024 18:23:30 +0200 Subject: [PATCH] tensor: contractions over severeal indices implemented --- t.cc | 37 ++++++++++- tensor.cc | 188 +++++++++++++++++++++++++++++++++++++++++++++++++++++- tensor.h | 20 +++++- vec.cc | 1 + 4 files changed, 239 insertions(+), 7 deletions(-) diff --git a/t.cc b/t.cc index 1f9c871..0f2f662 100644 --- a/t.cc +++ b/t.cc @@ -3316,7 +3316,12 @@ bg.range=n; Tensor b(bg); b.randomize(1.); -Tensor cc = a.contraction(0,0,b,0,1); +INDEXLIST il1(1); +il1[0]={0,0}; +INDEXLIST il2(1); +il2[0]={0,1}; +Tensor cc = a.contractions(il1,b,il2); +//Tensor cc = a.contraction(0,0,b,0,1); cout <1e-13) laerror("internal error in conntraction"); + if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in contraction"); } //cout < e(g); +e.randomize(1.); +INDEXLIST il(2); +il[0]= {0,1}; +il[1]= {0,3}; +Tensor eu = e.unwind_indices(il); + +for(int i=0; i &shape) +{ +int ii=0; +for(int g=0; g &shape) +{ +int ii=0; +for(int g=0; g static void unwind_callback(const SUPERINDEX &I, T *v) @@ -648,7 +664,119 @@ return r; template -static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A * B^T +Tensor Tensor::unwind_indices(const INDEXLIST &il) const +{ +if(il.size()==0) return *this; +if(il.size()==1) return unwind_index(il[0].group,il[0].index); + +for(int i=0; i=shape.size()) laerror("wrong group number in unwind_indices"); + if(il[i].index<0||il[i].index>=shape[il[i].group].number) laerror("wrong index number in unwind_indices"); + } + +//all indices are solo in their groups - permute groups +bool sologroups=true; +int nonsolo=0; +for(int i=0; i p(shape.size()); + bitvector waslisted(shape.size()); + waslisted.clear(); + for(int i=0; i oldshape(shape); +oldshape.copyonwrite(); +NRVec newshape(shape.size()+nonsolo); + +//first the unwound indices as solo groups +for(int i=0; i0) + { + newshape[ii++] = oldshape[i]; + } + +Tensor r(newshape); +if(r.rank()!=rank()) laerror("internal error 2 in unwind_indces"); + +//compute the corresponding permutation of FLATINDEX for use in the callback +NRPerm indexperm(rank()); +bitvector waslisted(rank()); +waslisted.clear(); +//first unwound indices +ii=0; +for(int i=0; i = this; +help_p = &indexperm; +r.loopover(unwind_callback); +return r; +} + +template +static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A(nn,kk) * B^T(mm,kk) { for(int i=0; i -void Tensor::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate) +void Tensor::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate) { if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction"); if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction"); @@ -690,6 +818,7 @@ if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction"); Tensor u = rhs1.unwind_index(group,index); +if(conjugate1) u.conjugateme(); Tensor rhsu = rhs.unwind_index(rhsgroup,rhsindex); @@ -709,7 +838,7 @@ else } int nn,mm,kk; kk=u.groupsizes[0]; -if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction"); +if(kk!=rhsu.groupsizes[0]) laerror("internal error in addcontraction"); nn=1; for(int i=1; i(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate); } +template +void Tensor::addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate2) +{ +if(il1.size()==0) laerror("empty contraction - outer product not implemented"); +if(il1.size()!=il2.size()) laerror("mismatch in index lists in addcontractions"); +for(int i=0; i=rhs1.shape.size()) laerror("wrong group1 number in contractions"); + if(il2[i].group<0||il2[i].group>=rhs2.shape.size()) laerror("wrong group2 number in contractions"); + if(il1[i].index<0||il1[i].index>=rhs1.shape[il1[i].group].number) laerror("wrong index1 number in conntractions"); + if(il2[i].index<0||il2[i].index>=rhs2.shape[il2[i].group].number) laerror("wrong index2 number in conntractions"); + if(rhs1.shape[il1[i].group].offset != rhs2.shape[il2[i].group].offset) laerror("incompatible index offset in contractions"); + if(rhs1.shape[il1[i].group].range != rhs2.shape[il2[i].group].range) laerror("incompatible index range in contractions"); + } + +Tensor u = rhs1.unwind_indices(il1); +if(conjugate1) u.conjugateme(); +Tensor rhsu = rhs2.unwind_indices(il2); + + +NRVec newshape(u.shape.size()+rhsu.shape.size()-2*il1.size()); +int ii=0; +for(int i=il1.size(); i(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate2); +} + + + + template static const PermutationAlgebra *help_pa; diff --git a/tensor.h b/tensor.h index c71f86a..77180a9 100644 --- a/tensor.h +++ b/tensor.h @@ -36,6 +36,9 @@ #include "smat.h" #include "miscfunc.h" +//@@@todo - outer product +//@@@permutation of individual indices??? how to treat the symmetry groups +//@@@todo - index names and contraction by named index list namespace LA { @@ -98,6 +101,15 @@ class LA_traits { typedef NRVec FLATINDEX; //all indices but in a single vector typedef NRVec > SUPERINDEX; //all indices in the INDEXGROUP structure typedef NRVec GROUPINDEX; //set of indices in the symmetry groups +struct INDEX +{ +int group; +int index; +}; +typedef NRVec INDEXLIST; //collection of several indices + +int flatposition(const INDEX &i, const NRVec &shape); //position of that index in FLATINDEX +int flatposition(const INDEX &i, const NRVec &shape); FLATINDEX superindex2flat(const SUPERINDEX &I); @@ -184,8 +196,12 @@ public: Tensor permute_index_groups(const NRPerm &p) const; //rearrange the tensor storage permuting index groups as a whole Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one - void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate=false); - inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate=false) const {Tensor r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate); return r; } + Tensor unwind_indices(const INDEXLIST &il) const; //the same for a list of indices + void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs2, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate=false); //rhs1 will have more significant non-contracted indices in the result than rhs2 + inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate1=false, bool conjugate=false) const {Tensor r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate1, conjugate); return r; }; + + void addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate2=false); + inline Tensor contractions( const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {Tensor r; r.addcontractions(*this,il1,rhs2,il2,alpha,0,true,conjugate1, conjugate2); return r; }; void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list: // this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I)) diff --git a/vec.cc b/vec.cc index af8283e..a048009 100644 --- a/vec.cc +++ b/vec.cc @@ -909,6 +909,7 @@ void NRVec::storesubvector(const NRVec &selection, const NRVec &rhs) ******************************************************************************/ template NRVec& NRVec::conjugateme() { +copyonwrite(); #ifdef CUDALA if(location != cpu) laerror("general conjugation only on CPU"); #endif