tensor class: addcontraction

This commit is contained in:
Jiri Pittner 2024-04-30 16:38:16 +02:00
parent 27cc7854f5
commit 3d8386e30a
2 changed files with 20 additions and 10 deletions

View File

@ -679,16 +679,16 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk,
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices //The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
// //
template<typename T> template<typename T>
Tensor<T> Tensor<T>::contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha) const void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize)
{ {
if(group<0||group>=shape.size()) laerror("wrong group number in contraction"); if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction"); if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
if(index<0||index>=shape[group].number) laerror("wrong index number in conntraction"); if(index<0||index>=rhs1.shape[group].number) laerror("wrong index number in conntraction");
if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index number in conntraction"); if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index number in conntraction");
if(shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction"); if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
if(shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction"); if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
Tensor<T> u = unwind_index(group,index); Tensor<T> u = rhs1.unwind_index(group,index);
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex); Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
@ -697,18 +697,26 @@ int ii=0;
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i]; for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
Tensor<T> r(newshape); if(doresize)
{
if(beta!= (T)0) laerror("resize in addcontraction requires beta=0");
resize(newshape);
}
else
{
if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
}
int nn,mm,kk; int nn,mm,kk;
kk=u.groupsizes[0]; kk=u.groupsizes[0];
if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction"); if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i]; nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i]; mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
auxmatmult<T>(nn,mm,kk,&r.data[0],&u.data[0], &rhsu.data[0],alpha); auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta);
return r;
} }
template class Tensor<double>; template class Tensor<double>;
template class Tensor<std::complex<double> >; template class Tensor<std::complex<double> >;
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x); template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);

View File

@ -132,6 +132,7 @@ public:
LA_largeindex calcsize(); //set redundant data and return total size LA_largeindex calcsize(); //set redundant data and return total size
LA_largeindex size() const {return data.size();}; LA_largeindex size() const {return data.size();};
void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();}; void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();};
void resize(const NRVec<indexgroup> &s) {shape=s; data.resize(calcsize()); calcrank();};
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);}; inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);}; inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
@ -179,7 +180,8 @@ public:
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const; void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false);
inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const {Tensor<T> r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true); return r; }
//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator? //@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne //@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne