tensor: contractions over severeal indices implemented
This commit is contained in:
parent
883d201e67
commit
0b91e88dca
37
t.cc
37
t.cc
@ -3316,7 +3316,12 @@ bg.range=n;
|
|||||||
Tensor<double> b(bg);
|
Tensor<double> b(bg);
|
||||||
b.randomize(1.);
|
b.randomize(1.);
|
||||||
|
|
||||||
Tensor<double> cc = a.contraction(0,0,b,0,1);
|
INDEXLIST il1(1);
|
||||||
|
il1[0]={0,0};
|
||||||
|
INDEXLIST il2(1);
|
||||||
|
il2[0]={0,1};
|
||||||
|
Tensor<double> cc = a.contractions(il1,b,il2);
|
||||||
|
//Tensor<double> cc = a.contraction(0,0,b,0,1);
|
||||||
cout <<cc;
|
cout <<cc;
|
||||||
|
|
||||||
INDEXGROUP cga;
|
INDEXGROUP cga;
|
||||||
@ -3344,7 +3349,7 @@ for(int i=0; i<n; ++i)
|
|||||||
{
|
{
|
||||||
for(int p=0; p<n; ++p)
|
for(int p=0; p<n; ++p)
|
||||||
c.lhs(m,l,k,j,i) += a(p,i,j,k) * b(m,p,l);
|
c.lhs(m,l,k,j,i) += a(p,i,j,k) * b(m,p,l);
|
||||||
if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in conntraction");
|
if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in contraction");
|
||||||
}
|
}
|
||||||
|
|
||||||
//cout <<c;
|
//cout <<c;
|
||||||
@ -3352,4 +3357,32 @@ for(int i=0; i<n; ++i)
|
|||||||
|
|
||||||
//test Tensor apply_permutation_algebra
|
//test Tensor apply_permutation_algebra
|
||||||
|
|
||||||
|
//test unwind_indices
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
int n=5;
|
||||||
|
INDEXGROUP g;
|
||||||
|
g.number=4;
|
||||||
|
g.symmetry= -1;
|
||||||
|
g.offset=0;
|
||||||
|
g.range=n;
|
||||||
|
|
||||||
|
Tensor<double> e(g);
|
||||||
|
e.randomize(1.);
|
||||||
|
INDEXLIST il(2);
|
||||||
|
il[0]= {0,1};
|
||||||
|
il[1]= {0,3};
|
||||||
|
Tensor<double> eu = e.unwind_indices(il);
|
||||||
|
|
||||||
|
for(int i=0; i<n; ++i)
|
||||||
|
for(int j=0; j<n; ++j)
|
||||||
|
for(int k=0; k<n; ++k)
|
||||||
|
for(int l=0; l<n; ++l)
|
||||||
|
{
|
||||||
|
if(e(i,j,k,l)!=eu(j,l,i,k)) laerror("error in unwind_indces");
|
||||||
|
}
|
||||||
|
cout <<e;
|
||||||
|
cout <<eu;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
188
tensor.cc
188
tensor.cc
@ -567,6 +567,22 @@ for(int i=0; i<I.size(); ++i)
|
|||||||
return J;
|
return J;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape)
|
||||||
|
{
|
||||||
|
int ii=0;
|
||||||
|
for(int g=0; g<i.group; ++g) ii+= shape[g].number;
|
||||||
|
ii += i.index;
|
||||||
|
return ii;
|
||||||
|
}
|
||||||
|
|
||||||
|
int flatposition(int group, int index, const NRVec<indexgroup> &shape)
|
||||||
|
{
|
||||||
|
int ii=0;
|
||||||
|
for(int g=0; g<group; ++g) ii+= shape[g].number;
|
||||||
|
ii += index;
|
||||||
|
return ii;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void unwind_callback(const SUPERINDEX &I, T *v)
|
static void unwind_callback(const SUPERINDEX &I, T *v)
|
||||||
@ -648,7 +664,119 @@ return r;
|
|||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A * B^T
|
Tensor<T> Tensor<T>::unwind_indices(const INDEXLIST &il) const
|
||||||
|
{
|
||||||
|
if(il.size()==0) return *this;
|
||||||
|
if(il.size()==1) return unwind_index(il[0].group,il[0].index);
|
||||||
|
|
||||||
|
for(int i=0; i<il.size(); ++i)
|
||||||
|
{
|
||||||
|
if(il[i].group<0||il[i].group>=shape.size()) laerror("wrong group number in unwind_indices");
|
||||||
|
if(il[i].index<0||il[i].index>=shape[il[i].group].number) laerror("wrong index number in unwind_indices");
|
||||||
|
}
|
||||||
|
|
||||||
|
//all indices are solo in their groups - permute groups
|
||||||
|
bool sologroups=true;
|
||||||
|
int nonsolo=0;
|
||||||
|
for(int i=0; i<il.size(); ++i)
|
||||||
|
if(shape[il[i].group].number!=1) {sologroups=false; ++nonsolo;}
|
||||||
|
if(sologroups)
|
||||||
|
{
|
||||||
|
NRPerm<int> p(shape.size());
|
||||||
|
bitvector waslisted(shape.size());
|
||||||
|
waslisted.clear();
|
||||||
|
for(int i=0; i<il.size(); ++i)
|
||||||
|
{
|
||||||
|
p[1+i] = 1+il[i].group;
|
||||||
|
waslisted.set(il[i].group);
|
||||||
|
}
|
||||||
|
int ii=il.size();
|
||||||
|
for(int i=0; i<shape.size(); ++i)
|
||||||
|
{
|
||||||
|
if(!waslisted[i])
|
||||||
|
{
|
||||||
|
waslisted.set(i);
|
||||||
|
p[1+ii] = 1+i;
|
||||||
|
ii++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!p.is_valid()) laerror("internal error in unwind_indices");
|
||||||
|
if(p.is_identity()) return *this;
|
||||||
|
else return permute_index_groups(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//general case - recalculate the shape and allocate the new tensor
|
||||||
|
NRVec<indexgroup> oldshape(shape);
|
||||||
|
oldshape.copyonwrite();
|
||||||
|
NRVec<indexgroup> newshape(shape.size()+nonsolo);
|
||||||
|
|
||||||
|
//first the unwound indices as solo groups
|
||||||
|
for(int i=0; i<il.size(); ++i)
|
||||||
|
{
|
||||||
|
newshape[i].number=1;
|
||||||
|
newshape[i].symmetry=0;
|
||||||
|
newshape[i].range=shape[il[i].group].range;
|
||||||
|
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||||
|
newshape[i].offset = shape[il[i].group].offset;
|
||||||
|
#endif
|
||||||
|
oldshape[il[i].group].number --;
|
||||||
|
}
|
||||||
|
|
||||||
|
//then the remaining groups with one index removed, if nonempty
|
||||||
|
int ii=il.size();
|
||||||
|
for(int i=0; i<oldshape.size(); ++i)
|
||||||
|
if(oldshape[i].number>0)
|
||||||
|
{
|
||||||
|
newshape[ii++] = oldshape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor<T> r(newshape);
|
||||||
|
if(r.rank()!=rank()) laerror("internal error 2 in unwind_indces");
|
||||||
|
|
||||||
|
//compute the corresponding permutation of FLATINDEX for use in the callback
|
||||||
|
NRPerm<int> indexperm(rank());
|
||||||
|
bitvector waslisted(rank());
|
||||||
|
waslisted.clear();
|
||||||
|
//first unwound indices
|
||||||
|
ii=0;
|
||||||
|
for(int i=0; i<il.size(); ++i)
|
||||||
|
{
|
||||||
|
int pos= flatposition(il[i],shape);
|
||||||
|
indexperm[1+ii] = 1+pos;
|
||||||
|
waslisted.set(pos);
|
||||||
|
++ii;
|
||||||
|
}
|
||||||
|
|
||||||
|
//the remaining index groups
|
||||||
|
for(int g=0; g<shape.size(); ++g)
|
||||||
|
for(int i=0; i<shape[g].number; ++i)
|
||||||
|
{
|
||||||
|
int pos= flatposition(g,i,shape);
|
||||||
|
if(!waslisted[pos])
|
||||||
|
{
|
||||||
|
waslisted.set(pos);
|
||||||
|
indexperm[1+ii] = 1+pos;
|
||||||
|
++ii;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!indexperm.is_valid())
|
||||||
|
{
|
||||||
|
std::cout << "indexperm = "<<indexperm<<std::endl;
|
||||||
|
laerror("internal error 3 in unwind_indices");
|
||||||
|
}
|
||||||
|
|
||||||
|
//loop recursively and do the unwinding
|
||||||
|
help_tt<T> = this;
|
||||||
|
help_p = &indexperm;
|
||||||
|
r.loopover(unwind_callback);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A(nn,kk) * B^T(mm,kk)
|
||||||
{
|
{
|
||||||
for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
|
for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
|
||||||
{
|
{
|
||||||
@ -680,7 +808,7 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans),
|
|||||||
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
|
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
|
||||||
//
|
//
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate)
|
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate)
|
||||||
{
|
{
|
||||||
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
|
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
|
||||||
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
||||||
@ -690,6 +818,7 @@ if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible
|
|||||||
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||||
|
|
||||||
Tensor<T> u = rhs1.unwind_index(group,index);
|
Tensor<T> u = rhs1.unwind_index(group,index);
|
||||||
|
if(conjugate1) u.conjugateme();
|
||||||
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
||||||
|
|
||||||
|
|
||||||
@ -709,7 +838,7 @@ else
|
|||||||
}
|
}
|
||||||
int nn,mm,kk;
|
int nn,mm,kk;
|
||||||
kk=u.groupsizes[0];
|
kk=u.groupsizes[0];
|
||||||
if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
|
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addcontraction");
|
||||||
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||||
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||||
data.copyonwrite();
|
data.copyonwrite();
|
||||||
@ -717,6 +846,59 @@ auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate);
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate2)
|
||||||
|
{
|
||||||
|
if(il1.size()==0) laerror("empty contraction - outer product not implemented");
|
||||||
|
if(il1.size()!=il2.size()) laerror("mismatch in index lists in addcontractions");
|
||||||
|
for(int i=0; i<il1.size(); ++i)
|
||||||
|
{
|
||||||
|
if(il1[i].group<0||il1[i].group>=rhs1.shape.size()) laerror("wrong group1 number in contractions");
|
||||||
|
if(il2[i].group<0||il2[i].group>=rhs2.shape.size()) laerror("wrong group2 number in contractions");
|
||||||
|
if(il1[i].index<0||il1[i].index>=rhs1.shape[il1[i].group].number) laerror("wrong index1 number in conntractions");
|
||||||
|
if(il2[i].index<0||il2[i].index>=rhs2.shape[il2[i].group].number) laerror("wrong index2 number in conntractions");
|
||||||
|
if(rhs1.shape[il1[i].group].offset != rhs2.shape[il2[i].group].offset) laerror("incompatible index offset in contractions");
|
||||||
|
if(rhs1.shape[il1[i].group].range != rhs2.shape[il2[i].group].range) laerror("incompatible index range in contractions");
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor<T> u = rhs1.unwind_indices(il1);
|
||||||
|
if(conjugate1) u.conjugateme();
|
||||||
|
Tensor<T> rhsu = rhs2.unwind_indices(il2);
|
||||||
|
|
||||||
|
|
||||||
|
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2*il1.size());
|
||||||
|
int ii=0;
|
||||||
|
for(int i=il1.size(); i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
|
||||||
|
for(int i=il1.size(); i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
|
||||||
|
|
||||||
|
if(doresize)
|
||||||
|
{
|
||||||
|
if(beta!= (T)0) laerror("resize in addcontractions requires beta=0");
|
||||||
|
resize(newshape);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
|
||||||
|
}
|
||||||
|
int nn,mm,kk;
|
||||||
|
kk=1;
|
||||||
|
int kk2=1;
|
||||||
|
for(int i=0; i<il1.size(); ++i)
|
||||||
|
{
|
||||||
|
kk *= u.groupsizes[i];
|
||||||
|
kk2 *= rhsu.groupsizes[i];
|
||||||
|
}
|
||||||
|
if(kk!=kk2) laerror("internal error in addcontractions");
|
||||||
|
|
||||||
|
nn=1; for(int i=il1.size(); i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||||
|
mm=1; for(int i=il1.size(); i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||||
|
data.copyonwrite();
|
||||||
|
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static const PermutationAlgebra<int,T> *help_pa;
|
static const PermutationAlgebra<int,T> *help_pa;
|
||||||
|
|
||||||
|
20
tensor.h
20
tensor.h
@ -36,6 +36,9 @@
|
|||||||
#include "smat.h"
|
#include "smat.h"
|
||||||
#include "miscfunc.h"
|
#include "miscfunc.h"
|
||||||
|
|
||||||
|
//@@@todo - outer product
|
||||||
|
//@@@permutation of individual indices??? how to treat the symmetry groups
|
||||||
|
//@@@todo - index names and contraction by named index list
|
||||||
|
|
||||||
namespace LA {
|
namespace LA {
|
||||||
|
|
||||||
@ -98,6 +101,15 @@ class LA_traits<indexgroup> {
|
|||||||
typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
||||||
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
||||||
typedef NRVec<LA_largeindex> GROUPINDEX; //set of indices in the symmetry groups
|
typedef NRVec<LA_largeindex> GROUPINDEX; //set of indices in the symmetry groups
|
||||||
|
struct INDEX
|
||||||
|
{
|
||||||
|
int group;
|
||||||
|
int index;
|
||||||
|
};
|
||||||
|
typedef NRVec<INDEX> INDEXLIST; //collection of several indices
|
||||||
|
|
||||||
|
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape); //position of that index in FLATINDEX
|
||||||
|
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape);
|
||||||
|
|
||||||
FLATINDEX superindex2flat(const SUPERINDEX &I);
|
FLATINDEX superindex2flat(const SUPERINDEX &I);
|
||||||
|
|
||||||
@ -184,8 +196,12 @@ public:
|
|||||||
|
|
||||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||||
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
||||||
void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate=false);
|
Tensor unwind_indices(const INDEXLIST &il) const; //the same for a list of indices
|
||||||
inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate=false) const {Tensor<T> r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate); return r; }
|
void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs2, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate=false); //rhs1 will have more significant non-contracted indices in the result than rhs2
|
||||||
|
inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate1=false, bool conjugate=false) const {Tensor<T> r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate1, conjugate); return r; };
|
||||||
|
|
||||||
|
void addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate2=false);
|
||||||
|
inline Tensor contractions( const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {Tensor<T> r; r.addcontractions(*this,il1,rhs2,il2,alpha,0,true,conjugate1, conjugate2); return r; };
|
||||||
|
|
||||||
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
||||||
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
||||||
|
1
vec.cc
1
vec.cc
@ -909,6 +909,7 @@ void NRVec<T>::storesubvector(const NRVec<int> &selection, const NRVec &rhs)
|
|||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
template<typename T>
|
template<typename T>
|
||||||
NRVec<T>& NRVec<T>::conjugateme() {
|
NRVec<T>& NRVec<T>::conjugateme() {
|
||||||
|
copyonwrite();
|
||||||
#ifdef CUDALA
|
#ifdef CUDALA
|
||||||
if(location != cpu) laerror("general conjugation only on CPU");
|
if(location != cpu) laerror("general conjugation only on CPU");
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user