From 3d8386e30a322d0f4f4d7b67b21f4083022589b1 Mon Sep 17 00:00:00 2001
From: Jiri Pittner
Date: Tue, 30 Apr 2024 16:38:16 +0200
Subject: [PATCH] tensor class: addcontraction

---
 tensor.cc | 26 +++++++++++++++++---------
 tensor.h  |  4 +++-
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/tensor.cc b/tensor.cc
index 78e27c7..52b1785 100644
--- a/tensor.cc
+++ b/tensor.cc
@@ -679,16 +679,16 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk,
 //The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
 //
 template<typename T>
-Tensor<T> Tensor<T>::contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha) const
+void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize)
 {
-if(group<0||group>=shape.size()) laerror("wrong group number in contraction");
+if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
 if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
-if(index<0||index>=shape[group].number) laerror("wrong index number in conntraction");
+if(index<0||index>=rhs1.shape[group].number) laerror("wrong index number in conntraction");
 if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index number in conntraction");
-if(shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
-if(shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
+if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
+if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
 
-Tensor u = unwind_index(group,index);
+Tensor u = rhs1.unwind_index(group,index);
 Tensor rhsu = rhs.unwind_index(rhsgroup,rhsindex);
 
 
@@ -697,18 +697,26 @@ int ii=0;
 for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
 for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i];
 
-Tensor<T> r(newshape);
+if(doresize)
+	{
+	if(beta!= (T)0) laerror("resize in addcontraction requires beta=0");
+	resize(newshape);
+	}
+else
+	{
+	if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
+	}
 int nn,mm,kk;
 kk=u.groupsizes[0];
 if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
 nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
 mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
-auxmatmult<T>(nn,mm,kk,&r.data[0],&u.data[0], &rhsu.data[0],alpha);
-return r;
+auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta);
 }
 
 
 
+
 template class Tensor<double>;
 template class Tensor<complex<double> >;
 template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);
diff --git a/tensor.h b/tensor.h
index 7fe4a69..6f727c0 100644
--- a/tensor.h
+++ b/tensor.h
@@ -132,6 +132,7 @@ public:
 LA_largeindex calcsize(); //set redundant data and return total size
 LA_largeindex size() const {return data.size();};
 void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();};
+void resize(const NRVec<indexgroup> &s) {shape=s; data.resize(calcsize()); calcrank();};
 inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
 inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
 inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
@@ -179,7 +180,8 @@ public:
 Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
 Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
-Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const;
+void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false);
+inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const {Tensor r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true); return r; }
 //@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
 //@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - in general any symmetry change is possible except +1 to -1
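A minimal usage sketch of the new interface, based only on the declarations added to tensor.h above; the tensors a and b, the group/index numbers 0,0, the LA namespace, and the function name contraction_example are illustrative assumptions, not part of the patch. contraction() remains available as the inline wrapper that resizes and overwrites its result (beta=0, doresize=true), while addcontraction() can accumulate alpha times the contraction into an already correctly shaped tensor (beta=1, doresize=false).

#include "tensor.h"     // header modified by this patch
using namespace LA;     // namespace assumed for the LA library

// a and b are assumed to carry a contracted index with matching offset and range;
// the group and index numbers used here are purely illustrative
void contraction_example(const Tensor<double> &a, const Tensor<double> &b)
{
Tensor<double> r = a.contraction(0, 0, b, 0, 0, 1.);   // allocates the result and overwrites it

Tensor<double> r2;
r2.addcontraction(a, 0, 0, b, 0, 0, 1., 0., true);     // beta=0, doresize=true: resize r2 and overwrite it
r2.addcontraction(a, 0, 0, b, 0, 0, 1., 1., false);    // beta=1, doresize=false: accumulate into the existing r2
}

The alpha/beta convention follows the BLAS gemm call visible in the first hunk's context, so beta=0 discards the previous contents of the result while beta=1 adds the new contraction to them.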