tensor: addgroupcontractions
This commit is contained in:
parent
5e4c4dddea
commit
b3c7d21268
4
t.cc
4
t.cc
@ -3362,7 +3362,7 @@ for(int i=0; i<n; ++i)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
if(0)
|
if(1)
|
||||||
{
|
{
|
||||||
int n=5;
|
int n=5;
|
||||||
INDEXGROUP ag;
|
INDEXGROUP ag;
|
||||||
@ -3923,7 +3923,7 @@ for(int l=1; l<=n; ++l)
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(1)
|
if(0)
|
||||||
{
|
{
|
||||||
//tensor lhs operator() and signed pointer
|
//tensor lhs operator() and signed pointer
|
||||||
int n;
|
int n;
|
||||||
|
|||||||
152
tensor.cc
152
tensor.cc
@ -753,26 +753,41 @@ FLATINDEX JP = J.permuted(*help_p,false);
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
Tensor<T> Tensor<T>::unwind_index_group(int group) const
|
||||||
|
{
|
||||||
|
if(group==0) return *this; //is already the least significant group
|
||||||
|
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index_group");
|
||||||
|
|
||||||
|
NRPerm<int> p(shape.size());
|
||||||
|
p[1]= 1+group;
|
||||||
|
int ii=1;
|
||||||
|
if(ii==1+group) ii++; //skip this
|
||||||
|
for(int i=2; i<=shape.size(); ++i)
|
||||||
|
{
|
||||||
|
p[i]=ii++;
|
||||||
|
if(ii==1+group) ii++; //skip this
|
||||||
|
}
|
||||||
|
if(!p.is_valid()) laerror("internal error in unwind_index");
|
||||||
|
return permute_index_groups(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
|
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
|
||||||
{
|
{
|
||||||
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
|
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
|
||||||
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
|
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
|
||||||
if(shape[group].number==1) //single index in the group
|
|
||||||
|
if(group==0 && index==0 && shape[0].number == 1) return *this;
|
||||||
|
if(group==0 && index==0 && shape[0].symmetry==0) //formal split of 1 index without data rearrangement
|
||||||
{
|
{
|
||||||
if(group==0) return *this; //is already the least significant group
|
Tensor<T> r(*this);
|
||||||
NRPerm<int> p(shape.size());
|
r.split_index_group1(0);
|
||||||
p[1]= 1+group;
|
return r;
|
||||||
int ii=1;
|
|
||||||
if(ii==1+group) ii++; //skip this
|
|
||||||
for(int i=2; i<=shape.size(); ++i)
|
|
||||||
{
|
|
||||||
p[i]=ii++;
|
|
||||||
if(ii==1+group) ii++; //skip this
|
|
||||||
}
|
|
||||||
if(!p.is_valid()) laerror("internal error in unwind_index");
|
|
||||||
return permute_index_groups(p);
|
|
||||||
}
|
}
|
||||||
|
if(shape[group].number==1) return unwind_index_group(group); //single index in the group
|
||||||
|
|
||||||
//general case - recalculate the shape and allocate the new tensor
|
//general case - recalculate the shape and allocate the new tensor
|
||||||
NRVec<indexgroup> newshape(shape.size()+1);
|
NRVec<indexgroup> newshape(shape.size()+1);
|
||||||
@ -1088,6 +1103,65 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans),
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::addgroupcontraction(const Tensor &rhs1, int group, const Tensor &rhs, int rhsgroup, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate)
|
||||||
|
{
|
||||||
|
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
|
||||||
|
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
||||||
|
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
|
||||||
|
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||||
|
if(rhs1.shape[group].symmetry != rhs.shape[rhsgroup].symmetry) laerror("incompatible index symmetry in addgroupcontraction");
|
||||||
|
if(rhs1.shape[group].symmetry == 1) laerror("addgroupcontraction not implemented for symmetric index groups");
|
||||||
|
#ifdef LA_TENSOR_INDEXPOSITION
|
||||||
|
if(rhs1.shape[group].upperindex ^ rhs.shape[rhsgroup].upperindex == false) laerror("can contact only upper with lower index");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const Tensor<T> u = conjugate1? (rhs1.unwind_index_group(group)).conjugate() : rhs1.unwind_index_group(group);
|
||||||
|
const Tensor<T> rhsu = rhs.unwind_index_group(rhsgroup);
|
||||||
|
|
||||||
|
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2);
|
||||||
|
int ii=0;
|
||||||
|
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
|
||||||
|
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
|
||||||
|
|
||||||
|
NRVec<INDEXNAME> newnames;
|
||||||
|
if(u.is_named() && rhsu.is_named())
|
||||||
|
{
|
||||||
|
for(int i=0; i<u.shape[0].number; ++i)
|
||||||
|
{
|
||||||
|
if(u.names[i]!=rhsu.names[i]) laerror("contraction index name mismatch in addgroupcontraction");
|
||||||
|
}
|
||||||
|
newnames.resize(u.names.size()+rhsu.names.size()-2*u.shape[0].number);
|
||||||
|
int ii=0;
|
||||||
|
for(int i=rhsu.shape[0].number; i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
|
||||||
|
for(int i=u.shape[0].number; i<u.names.size(); ++i) newnames[ii++] = u.names[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if(doresize)
|
||||||
|
{
|
||||||
|
if(beta!= (T)0) laerror("resize in addgroupcontraction requires beta=0");
|
||||||
|
resize(newshape);
|
||||||
|
names=newnames;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if(shape!=newshape) laerror("tensor shape mismatch in addgroupcontraction");
|
||||||
|
if(is_named() && names!=newnames) laerror("remaining index names mismatch in addgroupcontraction");
|
||||||
|
}
|
||||||
|
|
||||||
|
int nn,mm,kk;
|
||||||
|
kk=u.groupsizes[0];
|
||||||
|
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addgroupcontraction");
|
||||||
|
T factor=alpha;
|
||||||
|
if(u.shape[0].symmetry== -1) factor=alpha*(T)factorial(u.shape[0].number);
|
||||||
|
if(u.shape[0].symmetry== 1) laerror("addgroupcontraction not implemented for symmetric index groups");
|
||||||
|
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||||
|
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||||
|
data.copyonwrite();
|
||||||
|
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],factor,beta,conjugate);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//Conntraction could be implemented without the temporary storage for unwinding, but then we would need
|
//Conntraction could be implemented without the temporary storage for unwinding, but then we would need
|
||||||
@ -1168,9 +1242,8 @@ for(int i=0; i<il1.size(); ++i)
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
Tensor<T> u = rhs1.unwind_indices(il1);
|
const Tensor<T> u = conjugate1? (rhs1.unwind_indices(il1)).conjugateme() : rhs1.unwind_indices(il1);
|
||||||
if(conjugate1) u.conjugateme();
|
const Tensor<T> rhsu = rhs2.unwind_indices(il2);
|
||||||
Tensor<T> rhsu = rhs2.unwind_indices(il2);
|
|
||||||
|
|
||||||
|
|
||||||
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2*il1.size());
|
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2*il1.size());
|
||||||
@ -1337,6 +1410,36 @@ if(data.size()!=newsize) laerror("internal error in split_index_group");
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::split_index_group1(int group)
|
||||||
|
{
|
||||||
|
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||||
|
if(shape[group].number==1) return; //nothing to split
|
||||||
|
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
||||||
|
|
||||||
|
NRVec<indexgroup> newshape(shape.size()+1);
|
||||||
|
int gg=0;
|
||||||
|
for(int g=0; g<shape.size(); ++g)
|
||||||
|
{
|
||||||
|
if(g==group)
|
||||||
|
{
|
||||||
|
newshape[gg] = shape[group];
|
||||||
|
newshape[gg].number = 1;
|
||||||
|
gg++;
|
||||||
|
newshape[gg] = shape[group];
|
||||||
|
newshape[gg].number -= 1;
|
||||||
|
gg++;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
newshape[gg++] = shape[g];
|
||||||
|
}
|
||||||
|
|
||||||
|
shape=newshape;
|
||||||
|
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
|
||||||
|
if(data.size()!=newsize) laerror("internal error in split_index_group");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void Tensor<T>:: merge_adjacent_index_groups(int groupfrom, int groupto)
|
void Tensor<T>:: merge_adjacent_index_groups(int groupfrom, int groupto)
|
||||||
{
|
{
|
||||||
@ -1391,6 +1494,23 @@ return r;
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
T Tensor<T>::dot(const Tensor &rhs) const
|
||||||
|
{
|
||||||
|
if(shape!=rhs.shape) laerror("incompatible tensor shapes in dot");
|
||||||
|
if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible tensor index names in dot");
|
||||||
|
T factor=1;
|
||||||
|
for(int i=0; i<shape.size(); ++i)
|
||||||
|
{
|
||||||
|
if(shape[i].symmetry==1) laerror("unsupported index group symmetry in dot");
|
||||||
|
if(shape[i].symmetry== -1) factor *= (T)factorial(shape[i].number);
|
||||||
|
}
|
||||||
|
return factor * data.dot(rhs.data);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
//NOTE: Tucker of rank=2 is inherently inefficient - result is a diagonal tensor stored in full and 2 calls to SVD
|
//NOTE: Tucker of rank=2 is inherently inefficient - result is a diagonal tensor stored in full and 2 calls to SVD
|
||||||
//we could avoid the second SVD, but the wasteful storage and reconstruction would remain
|
//we could avoid the second SVD, but the wasteful storage and reconstruction would remain
|
||||||
//
|
//
|
||||||
|
|||||||
8
tensor.h
8
tensor.h
@ -40,6 +40,7 @@
|
|||||||
#include "miscfunc.h"
|
#include "miscfunc.h"
|
||||||
|
|
||||||
//TODO:
|
//TODO:
|
||||||
|
//@@@@@@how to handle contractions yielding a scalar - special treatment, support special case of rank=0 tensor?
|
||||||
//@@@contraction inside one tensor - compute resulting shape, loopover the shape, create index into the original tensor + loop over the contr. index, do the summation, store result
|
//@@@contraction inside one tensor - compute resulting shape, loopover the shape, create index into the original tensor + loop over the contr. index, do the summation, store result
|
||||||
//@@@ will need to store vector of INDEX to the original tensor for the result's flatindex
|
//@@@ will need to store vector of INDEX to the original tensor for the result's flatindex
|
||||||
//@@@ will not be particularly efficient
|
//@@@ will not be particularly efficient
|
||||||
@ -48,6 +49,7 @@
|
|||||||
//
|
//
|
||||||
//@@@?general permutation of individual indices - check the indices in sym groups remain adjacent, calculate result's shape, loopover the result and permute using unwind_callback
|
//@@@?general permutation of individual indices - check the indices in sym groups remain adjacent, calculate result's shape, loopover the result and permute using unwind_callback
|
||||||
//@@@? apply_permutation_algebra if result should be symmetric/antisymmetric in such a way to compute only the nonredundant part
|
//@@@? apply_permutation_algebra if result should be symmetric/antisymmetric in such a way to compute only the nonredundant part
|
||||||
|
//@@@symetrizace a antisymetrizace skupiny indexu - jak efektivneji nez pres permutationalgebra?
|
||||||
//
|
//
|
||||||
|
|
||||||
|
|
||||||
@ -232,6 +234,7 @@ public:
|
|||||||
typename LA_traits<T>::normtype norm() const {return data.norm();};
|
typename LA_traits<T>::normtype norm() const {return data.norm();};
|
||||||
|
|
||||||
inline Tensor operator*(const Tensor &rhs) const {return Tensor(rhs.shape.concat(shape),data.otimes2vec(rhs.data),rhs.names.concat(names));} //outer product, rhs indices will be the less significant
|
inline Tensor operator*(const Tensor &rhs) const {return Tensor(rhs.shape.concat(shape),data.otimes2vec(rhs.data),rhs.names.concat(names));} //outer product, rhs indices will be the less significant
|
||||||
|
T dot(const Tensor &rhs) const; //scalar product (full contraction), in complex case is automatically conjugated, not for symmetric group indices
|
||||||
|
|
||||||
Tensor& conjugateme() {data.conjugateme(); return *this;};
|
Tensor& conjugateme() {data.conjugateme(); return *this;};
|
||||||
inline Tensor conjugate() const {Tensor r(*this); r.conjugateme(); return r;};
|
inline Tensor conjugate() const {Tensor r(*this); r.conjugateme(); return r;};
|
||||||
@ -275,6 +278,7 @@ public:
|
|||||||
|
|
||||||
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
|
||||||
|
|
||||||
|
Tensor unwind_index_group(int group) const; //make the index group leftmost (least significant)
|
||||||
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one (the leftmost one)
|
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one (the leftmost one)
|
||||||
Tensor unwind_index(const INDEX &I) const {return unwind_index(I.group,I.index);};
|
Tensor unwind_index(const INDEX &I) const {return unwind_index(I.group,I.index);};
|
||||||
Tensor unwind_index(const INDEXNAME &N) const {return unwind_index(findindex(N));};
|
Tensor unwind_index(const INDEXNAME &N) const {return unwind_index(findindex(N));};
|
||||||
@ -295,6 +299,9 @@ public:
|
|||||||
inline Tensor contractions( const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {Tensor<T> r; r.addcontractions(*this,il1,rhs2,il2,alpha,0,true,conjugate1, conjugate2); return r; };
|
inline Tensor contractions( const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {Tensor<T> r; r.addcontractions(*this,il1,rhs2,il2,alpha,0,true,conjugate1, conjugate2); return r; };
|
||||||
inline Tensor contractions(const Tensor &rhs2, const NRVec<INDEXNAME> names, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {return contractions(findindexlist(names),rhs2,rhs2.findindexlist(names),alpha,conjugate1,conjugate2); };
|
inline Tensor contractions(const Tensor &rhs2, const NRVec<INDEXNAME> names, T alpha=1, bool conjugate1=false, bool conjugate2=false) const {return contractions(findindexlist(names),rhs2,rhs2.findindexlist(names),alpha,conjugate1,conjugate2); };
|
||||||
|
|
||||||
|
void addgroupcontraction(const Tensor &rhs1, int group, const Tensor &rhs2, int rhsgroup, T alpha=1, T beta=1, bool doresize=false, bool conjugate1=false, bool conjugate=false); //over all indices in a group of same symmetry; rhs1 will have more significant non-contracted indices in the result than rhs2
|
||||||
|
inline Tensor groupcontraction(int group, const Tensor &rhs, int rhsgroup, T alpha=1, bool conjugate1=false, bool conjugate=false) const {Tensor<T> r; r.addgroupcontraction(*this,group,rhs,rhsgroup,alpha,0,true, conjugate1, conjugate); return r; };
|
||||||
|
|
||||||
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
||||||
void apply_permutation_algebra(const NRVec<Tensor> &rhsvec, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //avoids explicit outer product but not vectorized, rather inefficient
|
void apply_permutation_algebra(const NRVec<Tensor> &rhsvec, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //avoids explicit outer product but not vectorized, rather inefficient
|
||||||
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
||||||
@ -304,6 +311,7 @@ public:
|
|||||||
// More efficient would be applying permutation algebra symbolically and efficiently computing term by term
|
// More efficient would be applying permutation algebra symbolically and efficiently computing term by term
|
||||||
|
|
||||||
void split_index_group(int group); //formal in-place split of a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement
|
void split_index_group(int group); //formal in-place split of a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement
|
||||||
|
void split_index_group1(int group); //formal in-place split of the leftmost index in a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement
|
||||||
void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization or names rearrangement
|
void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization or names rearrangement
|
||||||
|
|
||||||
Tensor merge_index_groups(const NRVec<int> &groups) const;
|
Tensor merge_index_groups(const NRVec<int> &groups) const;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user