From 27cc7854f5b7ec2b6f38e713de730a36cab4d350 Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Thu, 25 Apr 2024 18:09:05 +0200 Subject: [PATCH] tensor class -contraction --- t.cc | 58 +++++++++++++++++++++++++++++++++++++++++++++- tensor.cc | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- tensor.h | 2 +- 3 files changed, 126 insertions(+), 3 deletions(-) diff --git a/t.cc b/t.cc index a865b14..22bc995 100644 --- a/t.cc +++ b/t.cc @@ -3270,7 +3270,7 @@ for(int i=0; i a(ag); +a.randomize(1.); + +INDEXGROUP bg; +bg.number=3; +bg.symmetry= 0; +bg.offset=0; +bg.range=n; + +Tensor b(bg); +b.randomize(1.); + +Tensor cc = a.contraction(0,0,b,0,1); +cout < shape({cgb,cga}); + +Tensor c(shape); +c.clear(); + +for(int i=0; i1e-13) laerror("internal error in conntraction"); + } + +//cout < p(shape.size()); p[1]= 1+group; int ii=1; + if(ii==1+group) ii++; //skip this for(int i=2; i<=shape.size(); ++i) { p[i]=ii++; @@ -625,12 +626,17 @@ if(r.rank()!=rank()) laerror("internal error 2 in unwind_index"); NRPerm indexperm(rank()); indexperm[1]=flatindex+1; int ii=1; +if(ii==flatindex+1) ii++; for(int i=2; i<=rank(); ++i) { indexperm[i] = ii++; if(ii==flatindex+1) ii++; //skip this } -if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index"); +if(!indexperm.is_valid()) + { + std::cout << "indexperm = "< = this; @@ -640,6 +646,67 @@ return r; } +template +static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0) //R(nn,mm) = A * B^T +{ +for(int i=0; i +void auxmatmult(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta) +{ +cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, alpha, a, kk, b, kk, beta, r, mm); +} + +template<> +void auxmatmult >(int nn, int mm, int kk, std::complex *r, std::complex *a, std::complex *b, std::complex alpha, std::complex beta) +{ +cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm); +} + + + + + +//Conntraction could be implemented without the temporary storage for unwinding, but then we would need +//double recursion over indices of both tensors. Hopefully using the matrix multiplication here +//makes it also more efficient, even for (anti)symmetric indices +//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices +// +template +Tensor Tensor::contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha) const +{ +if(group<0||group>=shape.size()) laerror("wrong group number in contraction"); +if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction"); +if(index<0||index>=shape[group].number) laerror("wrong index number in conntraction"); +if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index number in conntraction"); +if(shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction"); +if(shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction"); + +Tensor u = unwind_index(group,index); +Tensor rhsu = rhs.unwind_index(rhsgroup,rhsindex); + + +NRVec newshape(u.shape.size()+rhsu.shape.size()-2); +int ii=0; +for(int i=1; i r(newshape); +int nn,mm,kk; +kk=u.groupsizes[0]; +if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction"); +nn=1; for(int i=1; i(nn,mm,kk,&r.data[0],&u.data[0], &rhsu.data[0],alpha); +return r; +} + template class Tensor; diff --git a/tensor.h b/tensor.h index d4931da..7fe4a69 100644 --- a/tensor.h +++ b/tensor.h @@ -179,10 +179,10 @@ public: Tensor permute_index_groups(const NRPerm &p) const; //rearrange the tensor storage permuting index groups as a whole Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one + Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const; //@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator? //@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne - //@@@contraction };