tensor: addgroupcontractions
This commit is contained in:
152
tensor.cc
152
tensor.cc
@@ -753,26 +753,41 @@ FLATINDEX JP = J.permuted(*help_p,false);
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::unwind_index_group(int group) const
|
||||
{
|
||||
if(group==0) return *this; //is already the least significant group
|
||||
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index_group");
|
||||
|
||||
NRPerm<int> p(shape.size());
|
||||
p[1]= 1+group;
|
||||
int ii=1;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
for(int i=2; i<=shape.size(); ++i)
|
||||
{
|
||||
p[i]=ii++;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
}
|
||||
if(!p.is_valid()) laerror("internal error in unwind_index");
|
||||
return permute_index_groups(p);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::unwind_index(int group, int index) const
|
||||
{
|
||||
if(group<0||group>=shape.size()) laerror("wrong group number in unwind_index");
|
||||
if(index<0||index>=shape[group].number) laerror("wrong index number in unwind_index");
|
||||
if(shape[group].number==1) //single index in the group
|
||||
|
||||
if(group==0 && index==0 && shape[0].number == 1) return *this;
|
||||
if(group==0 && index==0 && shape[0].symmetry==0) //formal split of 1 index without data rearrangement
|
||||
{
|
||||
if(group==0) return *this; //is already the least significant group
|
||||
NRPerm<int> p(shape.size());
|
||||
p[1]= 1+group;
|
||||
int ii=1;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
for(int i=2; i<=shape.size(); ++i)
|
||||
{
|
||||
p[i]=ii++;
|
||||
if(ii==1+group) ii++; //skip this
|
||||
}
|
||||
if(!p.is_valid()) laerror("internal error in unwind_index");
|
||||
return permute_index_groups(p);
|
||||
Tensor<T> r(*this);
|
||||
r.split_index_group1(0);
|
||||
return r;
|
||||
}
|
||||
if(shape[group].number==1) return unwind_index_group(group); //single index in the group
|
||||
|
||||
//general case - recalculate the shape and allocate the new tensor
|
||||
NRVec<indexgroup> newshape(shape.size()+1);
|
||||
@@ -1088,6 +1103,65 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans),
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::addgroupcontraction(const Tensor &rhs1, int group, const Tensor &rhs, int rhsgroup, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate)
|
||||
{
|
||||
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
|
||||
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
|
||||
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
|
||||
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||
if(rhs1.shape[group].symmetry != rhs.shape[rhsgroup].symmetry) laerror("incompatible index symmetry in addgroupcontraction");
|
||||
if(rhs1.shape[group].symmetry == 1) laerror("addgroupcontraction not implemented for symmetric index groups");
|
||||
#ifdef LA_TENSOR_INDEXPOSITION
|
||||
if(rhs1.shape[group].upperindex ^ rhs.shape[rhsgroup].upperindex == false) laerror("can contact only upper with lower index");
|
||||
#endif
|
||||
|
||||
const Tensor<T> u = conjugate1? (rhs1.unwind_index_group(group)).conjugate() : rhs1.unwind_index_group(group);
|
||||
const Tensor<T> rhsu = rhs.unwind_index_group(rhsgroup);
|
||||
|
||||
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2);
|
||||
int ii=0;
|
||||
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
|
||||
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
|
||||
|
||||
NRVec<INDEXNAME> newnames;
|
||||
if(u.is_named() && rhsu.is_named())
|
||||
{
|
||||
for(int i=0; i<u.shape[0].number; ++i)
|
||||
{
|
||||
if(u.names[i]!=rhsu.names[i]) laerror("contraction index name mismatch in addgroupcontraction");
|
||||
}
|
||||
newnames.resize(u.names.size()+rhsu.names.size()-2*u.shape[0].number);
|
||||
int ii=0;
|
||||
for(int i=rhsu.shape[0].number; i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
|
||||
for(int i=u.shape[0].number; i<u.names.size(); ++i) newnames[ii++] = u.names[i];
|
||||
}
|
||||
|
||||
if(doresize)
|
||||
{
|
||||
if(beta!= (T)0) laerror("resize in addgroupcontraction requires beta=0");
|
||||
resize(newshape);
|
||||
names=newnames;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(shape!=newshape) laerror("tensor shape mismatch in addgroupcontraction");
|
||||
if(is_named() && names!=newnames) laerror("remaining index names mismatch in addgroupcontraction");
|
||||
}
|
||||
|
||||
int nn,mm,kk;
|
||||
kk=u.groupsizes[0];
|
||||
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addgroupcontraction");
|
||||
T factor=alpha;
|
||||
if(u.shape[0].symmetry== -1) factor=alpha*(T)factorial(u.shape[0].number);
|
||||
if(u.shape[0].symmetry== 1) laerror("addgroupcontraction not implemented for symmetric index groups");
|
||||
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
|
||||
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
|
||||
data.copyonwrite();
|
||||
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],factor,beta,conjugate);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//Conntraction could be implemented without the temporary storage for unwinding, but then we would need
|
||||
@@ -1168,9 +1242,8 @@ for(int i=0; i<il1.size(); ++i)
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor<T> u = rhs1.unwind_indices(il1);
|
||||
if(conjugate1) u.conjugateme();
|
||||
Tensor<T> rhsu = rhs2.unwind_indices(il2);
|
||||
const Tensor<T> u = conjugate1? (rhs1.unwind_indices(il1)).conjugateme() : rhs1.unwind_indices(il1);
|
||||
const Tensor<T> rhsu = rhs2.unwind_indices(il2);
|
||||
|
||||
|
||||
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2*il1.size());
|
||||
@@ -1337,6 +1410,36 @@ if(data.size()!=newsize) laerror("internal error in split_index_group");
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::split_index_group1(int group)
|
||||
{
|
||||
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||
if(shape[group].number==1) return; //nothing to split
|
||||
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
||||
|
||||
NRVec<indexgroup> newshape(shape.size()+1);
|
||||
int gg=0;
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
{
|
||||
if(g==group)
|
||||
{
|
||||
newshape[gg] = shape[group];
|
||||
newshape[gg].number = 1;
|
||||
gg++;
|
||||
newshape[gg] = shape[group];
|
||||
newshape[gg].number -= 1;
|
||||
gg++;
|
||||
}
|
||||
else
|
||||
newshape[gg++] = shape[g];
|
||||
}
|
||||
|
||||
shape=newshape;
|
||||
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
|
||||
if(data.size()!=newsize) laerror("internal error in split_index_group");
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>:: merge_adjacent_index_groups(int groupfrom, int groupto)
|
||||
{
|
||||
@@ -1391,6 +1494,23 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
T Tensor<T>::dot(const Tensor &rhs) const
|
||||
{
|
||||
if(shape!=rhs.shape) laerror("incompatible tensor shapes in dot");
|
||||
if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible tensor index names in dot");
|
||||
T factor=1;
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
if(shape[i].symmetry==1) laerror("unsupported index group symmetry in dot");
|
||||
if(shape[i].symmetry== -1) factor *= (T)factorial(shape[i].number);
|
||||
}
|
||||
return factor * data.dot(rhs.data);
|
||||
}
|
||||
|
||||
|
||||
|
||||
//NOTE: Tucker of rank=2 is inherently inefficient - result is a diagonal tensor stored in full and 2 calls to SVD
|
||||
//we could avoid the second SVD, but the wasteful storage and reconstruction would remain
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user