tensor class: addcontraction
This commit is contained in:
parent
27cc7854f5
commit
3d8386e30a
26
tensor.cc
26
tensor.cc
@ -679,16 +679,16 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk,
|
||||
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
|
||||
//
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha) const
|
||||
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize)
|
||||
{
|
||||
if(group<0||group>=shape.size()) laerror("wrong group number in contraction");
|
||||
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
|
||||
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
||||
if(index<0||index>=shape[group].number) laerror("wrong index number in conntraction");
|
||||
if(index<0||index>=rhs1.shape[group].number) laerror("wrong index number in conntraction");
|
||||
if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index number in conntraction");
|
||||
if(shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
|
||||
if(shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
|
||||
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||
|
||||
Tensor<T> u = unwind_index(group,index);
|
||||
Tensor<T> u = rhs1.unwind_index(group,index);
|
||||
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
||||
|
||||
|
||||
@ -697,18 +697,26 @@ int ii=0;
|
||||
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
|
||||
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(doresize)
|
||||
{
|
||||
if(beta!= (T)0) laerror("resize in addcontraction requires beta=0");
|
||||
resize(newshape);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
|
||||
}
|
||||
int nn,mm,kk;
|
||||
kk=u.groupsizes[0];
|
||||
if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
|
||||
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||
auxmatmult<T>(nn,mm,kk,&r.data[0],&u.data[0], &rhsu.data[0],alpha);
|
||||
return r;
|
||||
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);
|
||||
|
4
tensor.h
4
tensor.h
@ -132,6 +132,7 @@ public:
|
||||
LA_largeindex calcsize(); //set redundant data and return total size
|
||||
LA_largeindex size() const {return data.size();};
|
||||
void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite();};
|
||||
void resize(const NRVec<indexgroup> &s) {shape=s; data.resize(calcsize()); calcrank();};
|
||||
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
|
||||
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
|
||||
@ -179,7 +180,8 @@ public:
|
||||
|
||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
||||
Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const;
|
||||
void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false);
|
||||
inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const {Tensor<T> r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true); return r; }
|
||||
|
||||
//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
|
||||
//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne
|
||||
|
Loading…
Reference in New Issue
Block a user