tensor: contractions over severeal indices implemented
This commit is contained in:
parent
883d201e67
commit
0b91e88dca
37
t.cc
37
t.cc
@ -3316,7 +3316,12 @@ bg.range=n;
|
||||
Tensor<double> b(bg);
|
||||
b.randomize(1.);
|
||||
|
||||
Tensor<double> cc = a.contraction(0,0,b,0,1);
|
||||
INDEXLIST il1(1);
|
||||
il1[0]={0,0};
|
||||
INDEXLIST il2(1);
|
||||
il2[0]={0,1};
|
||||
Tensor<double> cc = a.contractions(il1,b,il2);
|
||||
//Tensor<double> cc = a.contraction(0,0,b,0,1);
|
||||
cout <<cc;
|
||||
|
||||
INDEXGROUP cga;
|
||||
@ -3344,7 +3349,7 @@ for(int i=0; i<n; ++i)
|
||||
{
|
||||
for(int p=0; p<n; ++p)
|
||||
c.lhs(m,l,k,j,i) += a(p,i,j,k) * b(m,p,l);
|
||||
if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in conntraction");
|
||||
if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in contraction");
|
||||
}
|
||||
|
||||
//cout <<c;
|
||||
@ -3352,4 +3357,32 @@ for(int i=0; i<n; ++i)
|
||||
|
||||
//test Tensor apply_permutation_algebra
|
||||
|
||||
//test unwind_indices
|
||||
if(0)
|
||||
{
|
||||
int n=5;
|
||||
INDEXGROUP g;
|
||||
g.number=4;
|
||||
g.symmetry= -1;
|
||||
g.offset=0;
|
||||
g.range=n;
|
||||
|
||||
Tensor<double> e(g);
|
||||
e.randomize(1.);
|
||||
INDEXLIST il(2);
|
||||
il[0]= {0,1};
|
||||
il[1]= {0,3};
|
||||
Tensor<double> eu = e.unwind_indices(il);
|
||||
|
||||
for(int i=0; i<n; ++i)
|
||||
for(int j=0; j<n; ++j)
|
||||
for(int k=0; k<n; ++k)
|
||||
for(int l=0; l<n; ++l)
|
||||
{
|
||||
if(e(i,j,k,l)!=eu(j,l,i,k)) laerror("error in unwind_indces");
|
||||
}
|
||||
cout <<e;
|
||||
cout <<eu;
|
||||
}
|
||||
|
||||
}
|
||||
|
188
tensor.cc
188
tensor.cc
@ -567,6 +567,22 @@ for(int i=0; i<I.size(); ++i)
|
||||
return J;
|
||||
}
|
||||
|
||||
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape)
|
||||
{
|
||||
int ii=0;
|
||||
for(int g=0; g<i.group; ++g) ii+= shape[g].number;
|
||||
ii += i.index;
|
||||
return ii;
|
||||
}
|
||||
|
||||
int flatposition(int group, int index, const NRVec<indexgroup> &shape)
|
||||
{
|
||||
int ii=0;
|
||||
for(int g=0; g<group; ++g) ii+= shape[g].number;
|
||||
ii += index;
|
||||
return ii;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void unwind_callback(const SUPERINDEX &I, T *v)
|
||||
@ -648,7 +664,119 @@ return r;
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A * B^T
|
||||
Tensor<T> Tensor<T>::unwind_indices(const INDEXLIST &il) const
|
||||
{
|
||||
if(il.size()==0) return *this;
|
||||
if(il.size()==1) return unwind_index(il[0].group,il[0].index);
|
||||
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
if(il[i].group<0||il[i].group>=shape.size()) laerror("wrong group number in unwind_indices");
|
||||
if(il[i].index<0||il[i].index>=shape[il[i].group].number) laerror("wrong index number in unwind_indices");
|
||||
}
|
||||
|
||||
//all indices are solo in their groups - permute groups
|
||||
bool sologroups=true;
|
||||
int nonsolo=0;
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
if(shape[il[i].group].number!=1) {sologroups=false; ++nonsolo;}
|
||||
if(sologroups)
|
||||
{
|
||||
NRPerm<int> p(shape.size());
|
||||
bitvector waslisted(shape.size());
|
||||
waslisted.clear();
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
p[1+i] = 1+il[i].group;
|
||||
waslisted.set(il[i].group);
|
||||
}
|
||||
int ii=il.size();
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
if(!waslisted[i])
|
||||
{
|
||||
waslisted.set(i);
|
||||
p[1+ii] = 1+i;
|
||||
ii++;
|
||||
}
|
||||
}
|
||||
|
||||
if(!p.is_valid()) laerror("internal error in unwind_indices");
|
||||
if(p.is_identity()) return *this;
|
||||
else return permute_index_groups(p);
|
||||
}
|
||||
|
||||
|
||||
//general case - recalculate the shape and allocate the new tensor
|
||||
NRVec<indexgroup> oldshape(shape);
|
||||
oldshape.copyonwrite();
|
||||
NRVec<indexgroup> newshape(shape.size()+nonsolo);
|
||||
|
||||
//first the unwound indices as solo groups
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
newshape[i].number=1;
|
||||
newshape[i].symmetry=0;
|
||||
newshape[i].range=shape[il[i].group].range;
|
||||
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||
newshape[i].offset = shape[il[i].group].offset;
|
||||
#endif
|
||||
oldshape[il[i].group].number --;
|
||||
}
|
||||
|
||||
//then the remaining groups with one index removed, if nonempty
|
||||
int ii=il.size();
|
||||
for(int i=0; i<oldshape.size(); ++i)
|
||||
if(oldshape[i].number>0)
|
||||
{
|
||||
newshape[ii++] = oldshape[i];
|
||||
}
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(r.rank()!=rank()) laerror("internal error 2 in unwind_indces");
|
||||
|
||||
//compute the corresponding permutation of FLATINDEX for use in the callback
|
||||
NRPerm<int> indexperm(rank());
|
||||
bitvector waslisted(rank());
|
||||
waslisted.clear();
|
||||
//first unwound indices
|
||||
ii=0;
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
int pos= flatposition(il[i],shape);
|
||||
indexperm[1+ii] = 1+pos;
|
||||
waslisted.set(pos);
|
||||
++ii;
|
||||
}
|
||||
|
||||
//the remaining index groups
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
for(int i=0; i<shape[g].number; ++i)
|
||||
{
|
||||
int pos= flatposition(g,i,shape);
|
||||
if(!waslisted[pos])
|
||||
{
|
||||
waslisted.set(pos);
|
||||
indexperm[1+ii] = 1+pos;
|
||||
++ii;
|
||||
}
|
||||
}
|
||||
|
||||
if(!indexperm.is_valid())
|
||||
{
|
||||
std::cout << "indexperm = "<<indexperm<<std::endl;
|
||||
laerror("internal error 3 in unwind_indices");
|
||||
}
|
||||
|
||||
//loop recursively and do the unwinding
|
||||
help_tt<T> = this;
|
||||
help_p = &indexperm;
|
||||
r.loopover(unwind_callback);
|
||||
return r;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A(nn,kk) * B^T(mm,kk)
|
||||
{
|
||||
for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
|
||||
{
|
||||
@ -680,7 +808,7 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans),
|
||||
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
|
||||
//
|
||||
template<typename T>
|
||||
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate)
|
||||
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate)
|
||||
{
|
||||
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
|
||||
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
||||
@ -690,6 +818,7 @@ if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible
|
||||
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||
|
||||
Tensor<T> u = rhs1.unwind_index(group,index);
|
||||
if(conjugate1) u.conjugateme();
|
||||
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
||||
|
||||
|
||||
@ -709,7 +838,7 @@ else
|
||||
}
|
||||
int nn,mm,kk;
|
||||
kk=u.groupsizes[0];
|
||||
if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
|
||||
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addcontraction");
|
||||
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||
data.copyonwrite();
|
||||
@ -717,6 +846,59 @@ auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate2)
|
||||
{
|
||||
if(il1.size()==0) laerror("empty contraction - outer product not implemented");
|
||||
if(il1.size()!=il2.size()) laerror("mismatch in index lists in addcontractions");
|
||||
for(int i=0; i<il1.size(); ++i)
|
||||
{
|
||||
if(il1[i].group<0||il1[i].group>=rhs1.shape.size()) laerror("wrong group1 number in contractions");
|
||||
if(il2[i].group<0||il2[i].group>=rhs2.shape.size()) laerror("wrong group2 number in contractions");
|
||||
if(il1[i].index<0||il1[i].index>=rhs1.shape[il1[i].group].number) laerror("wrong index1 number in conntractions");
|
||||
if(il2[i].index<0||il2[i].index>=rhs2.shape[il2[i].group].number) laerror("wrong index2 number in conntractions");
|
||||
if(rhs1.shape[il1[i].group].offset != rhs2.shape[il2[i].group].offset) laerror("incompatible index offset in contractions");
|
||||
if(rhs1.shape[il1[i].group].range != rhs2.shape[il2[i].group].range) laerror("incompatible index range in contractions");
|
||||
}
|
||||
|
||||
Tensor<T> u = rhs1.unwind_indices(il1);
|
||||
if(conjugate1) u.conjugateme();
|
||||
Tensor<T> rhsu = rhs2.unwind_indices(il2);
|
||||
|
||||
|
||||
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2*il1.size());
|
||||
int ii=0;
|
||||
for(int i=il1.size(); i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
|
||||
for(int i=il1.size(); i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
|
||||
|
||||
if(doresize)
|
||||
{
|
||||
if(beta!= (T)0) laerror("resize in addcontractions requires beta=0");
|
||||
resize(newshape);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
|
||||
}
|
||||
int nn,mm,kk;
|
||||
kk=1;
|
||||
int kk2=1;
|
||||
for(int i=0; i<il1.size(); ++i)
|
||||
{
|
||||
kk *= u.groupsizes[i];
|
||||
kk2 *= rhsu.groupsizes[i];
|
||||
}
|
||||
if(kk!=kk2) laerror("internal error in addcontractions");
|
||||
|
||||
nn=1; for(int i=il1.size(); i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||
mm=1; for(int i=il1.size(); i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||
data.copyonwrite();
|
||||
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate2);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
static const PermutationAlgebra<int,T> *help_pa;
|
||||
|
||||
|
20
tensor.h
20
tensor.h
@ -36,6 +36,9 @@
|
||||
#include "smat.h"
|
||||
#include "miscfunc.h"
|
||||
|
||||
//@@@todo - outer product
|
||||
//@@@permutation of individual indices??? how to treat the symmetry groups
|
||||
//@@@todo - index names and contraction by named index list
|
||||
|
||||
namespace LA {
|
||||
|
||||
@ -98,6 +101,15 @@ class LA_traits<indexgroup> {
|
||||
typedef NRVec<LA_index> FLATINDEX; //all indices but in a single vector
|
||||
typedef NRVec<NRVec<LA_index> > SUPERINDEX; //all indices in the INDEXGROUP structure
|
||||
typedef NRVec<LA_largeindex> GROUPINDEX; //set of indices in the symmetry groups
|
||||
struct INDEX
|
||||
{
|
||||
int group;
|
||||
int index;
|
||||
};
|
||||
typedef NRVec<INDEX> INDEXLIST; //collection of several indices
|
||||
|
||||
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape); //position of that index in FLATINDEX
|
||||
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape);
|
||||
|
||||
FLATINDEX superindex2flat(const SUPERINDEX &I);
|
||||
|
||||
@ -184,8 +196,12 @@ public:
|
||||
|
||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
|
||||
void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate=false);
|
||||
inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate=false) const {Tensor<T> r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate); return r; }
|
||||
Tensor unwind_indices(const INDEXLIST &il) const; //the same for a list of indices
|
||||
void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs2, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate=false); //rhs1 will have more significant non-contracted indices in the result than rhs2
|
||||
inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate1=false, bool conjugate=false) const {Tensor<T> r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate1, conjugate); return r; };
|
||||
|
||||
void addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate2=false);
|
||||
inline Tensor contractions( const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {Tensor<T> r; r.addcontractions(*this,il1,rhs2,il2,alpha,0,true,conjugate1, conjugate2); return r; };
|
||||
|
||||
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
||||
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
||||
|
1
vec.cc
1
vec.cc
@ -909,6 +909,7 @@ void NRVec<T>::storesubvector(const NRVec<int> &selection, const NRVec &rhs)
|
||||
******************************************************************************/
|
||||
template<typename T>
|
||||
NRVec<T>& NRVec<T>::conjugateme() {
|
||||
copyonwrite();
|
||||
#ifdef CUDALA
|
||||
if(location != cpu) laerror("general conjugation only on CPU");
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user