tensor class -contraction
This commit is contained in:
parent
5c6cb43c61
commit
27cc7854f5
58
t.cc
58
t.cc
@ -3270,7 +3270,7 @@ for(int i=0; i<n; ++i)
|
||||
}
|
||||
|
||||
|
||||
if(1)
|
||||
if(0)
|
||||
{
|
||||
int n=5;
|
||||
INDEXGROUP g;
|
||||
@ -3294,4 +3294,60 @@ cout <<e;
|
||||
cout <<eu;
|
||||
}
|
||||
|
||||
|
||||
if(1)
|
||||
{
|
||||
int n=5;
|
||||
INDEXGROUP ag;
|
||||
ag.number=4;
|
||||
ag.symmetry= 1;
|
||||
ag.offset=0;
|
||||
ag.range=n;
|
||||
|
||||
Tensor<double> a(ag);
|
||||
a.randomize(1.);
|
||||
|
||||
INDEXGROUP bg;
|
||||
bg.number=3;
|
||||
bg.symmetry= 0;
|
||||
bg.offset=0;
|
||||
bg.range=n;
|
||||
|
||||
Tensor<double> b(bg);
|
||||
b.randomize(1.);
|
||||
|
||||
Tensor<double> cc = a.contraction(0,0,b,0,1);
|
||||
cout <<cc;
|
||||
|
||||
INDEXGROUP cga;
|
||||
cga.number=3;
|
||||
cga.symmetry= 1;
|
||||
cga.offset=0;
|
||||
cga.range=n;
|
||||
|
||||
INDEXGROUP cgb;
|
||||
cgb.number=2;
|
||||
cgb.symmetry= 0;
|
||||
cgb.offset=0;
|
||||
cgb.range=n;
|
||||
|
||||
NRVec<INDEXGROUP> shape({cgb,cga});
|
||||
|
||||
Tensor<double> c(shape);
|
||||
c.clear();
|
||||
|
||||
for(int i=0; i<n; ++i)
|
||||
for(int j=0; j<=i; ++j)
|
||||
for(int k=0; k<=j; ++k)
|
||||
for(int l=0; l<n; ++l)
|
||||
for(int m=0; m<n; ++m)
|
||||
{
|
||||
for(int p=0; p<n; ++p)
|
||||
c.lhs(m,l,k,j,i) += a(p,i,j,k) * b(m,p,l);
|
||||
if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in conntraction");
|
||||
}
|
||||
|
||||
//cout <<c;
|
||||
}
|
||||
|
||||
}
|
||||
|
69
tensor.cc
69
tensor.cc
@ -589,6 +589,7 @@ if(shape[group].number==1) //single index in the group
|
||||
NRPerm<int> p(shape.size());
|
||||
p[1]= 1+group;
|
||||
int ii=1;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
for(int i=2; i<=shape.size(); ++i)
|
||||
{
|
||||
p[i]=ii++;
|
||||
@ -625,12 +626,17 @@ if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
|
||||
NRPerm<int> indexperm(rank());
|
||||
indexperm[1]=flatindex+1;
|
||||
int ii=1;
|
||||
if(ii==flatindex+1) ii++;
|
||||
for(int i=2; i<=rank(); ++i)
|
||||
{
|
||||
indexperm[i] = ii++;
|
||||
if(ii==flatindex+1) ii++; //skip this
|
||||
}
|
||||
if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index");
|
||||
if(!indexperm.is_valid())
|
||||
{
|
||||
std::cout << "indexperm = "<<indexperm<<std::endl;
|
||||
laerror("internal error 3 in unwind_index");
|
||||
}
|
||||
|
||||
//loop recursively and do the unwinding
|
||||
help_tt<T> = this;
|
||||
@ -640,6 +646,67 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0) //R(nn,mm) = A * B^T
|
||||
{
|
||||
for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
|
||||
{
|
||||
if(beta==0) r[i*mm+j]=0; else r[i*mm+j] *= beta;
|
||||
for(int k=0; k<kk; ++k) r[i*mm+j] += alpha * a[i*kk+k] * b[j*kk+k];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void auxmatmult<double>(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta)
|
||||
{
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, alpha, a, kk, b, kk, beta, r, mm);
|
||||
}
|
||||
|
||||
template<>
|
||||
void auxmatmult<std::complex<double> >(int nn, int mm, int kk, std::complex<double> *r, std::complex<double> *a, std::complex<double> *b, std::complex<double> alpha, std::complex<double> beta)
|
||||
{
|
||||
cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
//Conntraction could be implemented without the temporary storage for unwinding, but then we would need
|
||||
//double recursion over indices of both tensors. Hopefully using the matrix multiplication here
|
||||
//makes it also more efficient, even for (anti)symmetric indices
|
||||
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
|
||||
//
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha) const
|
||||
{
|
||||
if(group<0||group>=shape.size()) laerror("wrong group number in contraction");
|
||||
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
||||
if(index<0||index>=shape[group].number) laerror("wrong index number in conntraction");
|
||||
if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index number in conntraction");
|
||||
if(shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
|
||||
if(shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||
|
||||
Tensor<T> u = unwind_index(group,index);
|
||||
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
||||
|
||||
|
||||
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2);
|
||||
int ii=0;
|
||||
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
|
||||
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
int nn,mm,kk;
|
||||
kk=u.groupsizes[0];
|
||||
if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
|
||||
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||
auxmatmult<T>(nn,mm,kk,&r.data[0],&u.data[0], &rhsu.data[0],alpha);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
|
2
tensor.h
2
tensor.h
@ -179,10 +179,10 @@ public:
|
||||
|
||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
||||
Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const;
|
||||
|
||||
//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
|
||||
//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne
|
||||
//@@@contraction
|
||||
};
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user