tensor: addgroupcontractions

2025-11-06 15:07:13 +01:00
parent 5e4c4dddea
commit b3c7d21268
3 changed files with 146 additions and 18 deletions

tensor.cc (152 changed lines)

@@ -753,26 +753,41 @@ FLATINDEX JP = J.permuted(*help_p,false);
template<typename T>
Tensor<T> Tensor<T>::unwind_index_group(int group) const
{
if(group==0) return *this; //is already the least significant group
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index_group");
NRPerm<int> p(shape.size());
p[1]= 1+group;
int ii=1;
if(ii==1+group) ii++; //skip this
for(int i=2; i<=shape.size(); ++i)
{
p[i]=ii++;
if(ii==1+group) ii++; //skip this
}
if(!p.is_valid()) laerror("internal error in unwind_index_group");
return permute_index_groups(p);
}
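Note: a standalone sketch of the permutation built above, with a hypothetical buildUnwindPermutation helper that is not part of this commit. Slot 1 (the least significant group position in this library's ordering) receives the chosen group; the remaining groups keep their relative order.

#include <cassert>
#include <vector>

//hypothetical helper mirroring the loop above: builds the 1-based permutation
//(stored 0-based here) that moves group 'group' to slot 1 and preserves the
//relative order of the remaining groups
std::vector<int> buildUnwindPermutation(int ngroups, int group)
{
std::vector<int> p(ngroups);
p[0] = 1 + group; //slot 1 receives the unwound group
int ii = 1;
if(ii == 1 + group) ii++; //skip the unwound group
for(int i = 1; i < ngroups; ++i)
	{
	p[i] = ii++;
	if(ii == 1 + group) ii++; //skip the unwound group
	}
return p;
}

int main()
{
std::vector<int> p = buildUnwindPermutation(4, 2); //4 groups, unwind group 2 (0-based)
assert(p[0] == 3 && p[1] == 1 && p[2] == 2 && p[3] == 4); //a valid permutation of 1..4
return 0;
}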
template<typename T>
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
{
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
-if(shape[group].number==1) //single index in the group
+if(group==0 && index==0 && shape[0].number == 1) return *this;
+if(group==0 && index==0 && shape[0].symmetry==0) //formal split of 1 index without data rearrangement
{
-if(group==0) return *this; //is already the least significant group
-NRPerm<int> p(shape.size());
-p[1]= 1+group;
-int ii=1;
-if(ii==1+group) ii++; //skip this
-for(int i=2; i<=shape.size(); ++i)
-{
-p[i]=ii++;
-if(ii==1+group) ii++; //skip this
-}
-if(!p.is_valid()) laerror("internal error in unwind_index");
-return permute_index_groups(p);
+Tensor<T> r(*this);
+r.split_index_group1(0);
+return r;
}
+if(shape[group].number==1) return unwind_index_group(group); //single index in the group
//general case - recalculate the shape and allocate the new tensor
NRVec<indexgroup> newshape(shape.size()+1);
@@ -1088,6 +1103,65 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans),
template<typename T>
void Tensor<T>::addgroupcontraction(const Tensor &rhs1, int group, const Tensor &rhs, int rhsgroup, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate)
{
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
if(rhs1.shape[group].symmetry != rhs.shape[rhsgroup].symmetry) laerror("incompatible index symmetry in addgroupcontraction");
if(rhs1.shape[group].symmetry == 1) laerror("addgroupcontraction not implemented for symmetric index groups");
#ifdef LA_TENSOR_INDEXPOSITION
if((rhs1.shape[group].upperindex ^ rhs.shape[rhsgroup].upperindex) == false) laerror("can contract only upper with lower index");
#endif
const Tensor<T> u = conjugate1? (rhs1.unwind_index_group(group)).conjugate() : rhs1.unwind_index_group(group);
const Tensor<T> rhsu = rhs.unwind_index_group(rhsgroup);
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2);
int ii=0;
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
NRVec<INDEXNAME> newnames;
if(u.is_named() && rhsu.is_named())
{
for(int i=0; i<u.shape[0].number; ++i)
{
if(u.names[i]!=rhsu.names[i]) laerror("contraction index name mismatch in addgroupcontraction");
}
newnames.resize(u.names.size()+rhsu.names.size()-2*u.shape[0].number);
int ii=0;
for(int i=rhsu.shape[0].number; i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
for(int i=u.shape[0].number; i<u.names.size(); ++i) newnames[ii++] = u.names[i];
}
if(doresize)
{
if(beta!= (T)0) laerror("resize in addgroupcontraction requires beta=0");
resize(newshape);
names=newnames;
}
else
{
if(shape!=newshape) laerror("tensor shape mismatch in addgroupcontraction");
if(is_named() && names!=newnames) laerror("remaining index names mismatch in addgroupcontraction");
}
int nn,mm,kk;
kk=u.groupsizes[0];
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addgroupcontraction");
T factor=alpha;
if(u.shape[0].symmetry== -1) factor=alpha*(T)factorial(u.shape[0].number);
if(u.shape[0].symmetry== 1) laerror("addgroupcontraction not implemented for symmetric index groups");
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
data.copyonwrite();
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],factor,beta,conjugate);
}
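Note: a naive reference for what the auxmatmult call above computes, assuming only what is visible in this function: both operands are unwound so the contracted group of size kk runs fastest, and the result stores the rhs groups in the less significant positions. Hypothetical code, not part of this commit; conjugation is omitted.

#include <cstddef>

//naive reference loop for the contraction kernel invoked above
//(the real code dispatches to BLAS gemm); a is nn*kk, b is mm*kk,
//c is nn*mm, and the contracted dimension k runs fastest in a and b
template<typename T>
void naive_groupcontract(int nn, int mm, int kk, T *c, const T *a, const T *b, T factor, T beta)
{
for(int n = 0; n < nn; ++n)
	for(int m = 0; m < mm; ++m)
		{
		T s = 0;
		for(int k = 0; k < kk; ++k) s += a[(std::size_t)n*kk + k] * b[(std::size_t)m*kk + k];
		std::size_t idx = (std::size_t)n*mm + m;
		c[idx] = (beta == T(0) ? T(0) : beta * c[idx]) + factor * s; //never read c when beta==0
		}
}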
//Contraction could be implemented without the temporary storage for unwinding, but then we would need
@@ -1168,9 +1242,8 @@ for(int i=0; i<il1.size(); ++i)
#endif
}
-Tensor<T> u = rhs1.unwind_indices(il1);
-if(conjugate1) u.conjugateme();
-Tensor<T> rhsu = rhs2.unwind_indices(il2);
+const Tensor<T> u = conjugate1? (rhs1.unwind_indices(il1)).conjugateme() : rhs1.unwind_indices(il1);
+const Tensor<T> rhsu = rhs2.unwind_indices(il2);
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2*il1.size());
@@ -1337,6 +1410,36 @@ if(data.size()!=newsize) laerror("internal error in split_index_group");
}
template<typename T>
void Tensor<T>::split_index_group1(int group)
{
if(group<0||group >= shape.size()) laerror("illegal index group number");
if(shape[group].number==1) return; //nothing to split
if(shape[group].symmetry!=0) laerror("only a non-symmetric index group can be split, use flatten instead");
NRVec<indexgroup> newshape(shape.size()+1);
int gg=0;
for(int g=0; g<shape.size(); ++g)
{
if(g==group)
{
newshape[gg] = shape[group];
newshape[gg].number = 1;
gg++;
newshape[gg] = shape[group];
newshape[gg].number -= 1;
gg++;
}
else
newshape[gg++] = shape[g];
}
shape=newshape;
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
if(data.size()!=newsize) laerror("internal error in split_index_group");
}
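Note: the reshape above is pure bookkeeping. A non-symmetric group of n indices over a given range occupies range^n consecutive elements whether stored as one group or as 1 + (n-1), so the flat data never moves. A toy check with a hypothetical simplified Group struct (not the library's indexgroup):

#include <cassert>
#include <vector>

struct Group { int number; int range; }; //simplified stand-in for a non-symmetric indexgroup

//flat size of a shape made of non-symmetric groups: product of range^number
long flatsize(const std::vector<Group> &shape)
{
long s = 1;
for(const Group &g : shape)
	for(int i = 0; i < g.number; ++i) s *= g.range;
return s;
}

int main()
{
std::vector<Group> shape = { {3,4}, {2,5} }; //4^3 * 5^2 = 1600 elements
std::vector<Group> split = { {1,4}, {2,4}, {2,5} }; //group 0 split as 1 + 2
assert(flatsize(shape) == flatsize(split)); //same flat size, only the indexing changed
return 0;
}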
template<typename T>
void Tensor<T>::merge_adjacent_index_groups(int groupfrom, int groupto)
{
@@ -1391,6 +1494,23 @@ return r;
}
template<typename T>
T Tensor<T>::dot(const Tensor &rhs) const
{
if(shape!=rhs.shape) laerror("incompatible tensor shapes in dot");
if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible tensor index names in dot");
T factor=1;
for(int i=0; i<shape.size(); ++i)
{
if(shape[i].symmetry==1) laerror("unsupported index group symmetry in dot");
if(shape[i].symmetry== -1) factor *= (T)factorial(shape[i].number);
}
return factor * data.dot(rhs.data);
}
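Note: the factorial factor compensates for packed antisymmetric storage. An antisymmetric group of n indices stores only the combinations with i1<...<in, and each stored element represents n! full-tensor entries whose pairwise products are all equal, so the packed dot is scaled by n!. A numeric check for n=2 over real data (the complex case would additionally conjugate one side); hypothetical illustration, not part of this commit:

#include <cassert>

int main()
{
//antisymmetric 2-index tensor over range 3: a(i,j) = -a(j,i), zero diagonal
double full[3][3] = { {0,1,2}, {-1,0,3}, {-2,-3,0} };
double packed[3] = {1, 2, 3}; //stored combinations with i<j only
double fulldot = 0, packeddot = 0;
for(int i = 0; i < 3; ++i)
	for(int j = 0; j < 3; ++j) fulldot += full[i][j]*full[i][j];
for(int k = 0; k < 3; ++k) packeddot += packed[k]*packed[k];
assert(fulldot == 2.0*packeddot); //factor = 2! for a 2-index antisymmetric group
return 0;
}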
//NOTE: Tucker of rank=2 is inherently inefficient - result is a diagonal tensor stored in full and 2 calls to SVD
//we could avoid the second SVD, but the wasteful storage and reconstruction would remain
//