working on add_permuted_contractions

This commit is contained in:
2025-11-14 16:16:22 +01:00
parent 89cc0c5b1e
commit 3f7586378d
3 changed files with 127 additions and 7 deletions

View File

@@ -1198,7 +1198,7 @@ if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index numbe
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
#ifdef LA_TENSOR_INDEXPOSITION
if(rhs1.shape[group].upperindex ^ rhs.shape[rhsgroup].upperindex == false) laerror("can contact only upper with lower index");
if(rhs1.shape[group].upperindex ^ rhs.shape[rhsgroup].upperindex == false) laerror("can contract only upper with lower index");
#endif
const Tensor<T> u = conjugate1? (rhs1.unwind_index(group,index)).conjugate() : rhs1.unwind_index(group,index);
@@ -1458,9 +1458,11 @@ if(rank()!=rhs.rank()) laerror("rank mismatch in apply_permutation_algebra");
if(rhs.is_named())
{
NRVec<INDEXNAME> namperm = rhs.names.permuted(pa[0].perm,inverse);
NRVec<INDEXNAME> namperm = rhs.names.permuted(pa[0].perm,!inverse);
if(is_named())
{
std::cout <<"LHS names = "<<names<<std::endl;
std::cout <<"permuted RHS names = "<<namperm<<std::endl;
if(names!=namperm) laerror("inconsistent index names in apply_permutation_algebra");
}
else
@@ -1492,7 +1494,7 @@ if(allnamed)
{
NRVec<INDEXNAME> allrhsnames=rhsvec[0].names;
for(int i=1; i<rhsvec.size(); ++i) allrhsnames.concatme(rhsvec[i].names);
NRVec<INDEXNAME> namperm = allrhsnames.permuted(pa[0].perm,inverse);
NRVec<INDEXNAME> namperm = allrhsnames.permuted(pa[0].perm,!inverse);
if(is_named())
{
if(names!=namperm) laerror("inconsistent index names in apply_permutation_algebra");
@@ -1905,6 +1907,74 @@ std::istream & operator>>(std::istream &s, INDEX &x)
s>>x.group>>x.index;
return s;
}
template<typename T>
void Tensor<T>::add_permuted_contractions(const char *antisymmetrizer, const Tensor &rhs1, const Tensor &rhs2, T alpha, T beta, bool conjugate1, bool conjugate2)
{
if(!rhs1.is_uniquely_named()||!rhs2.is_uniquely_named()|| !is_uniquely_named()) laerror("tensors must have unique named indices in add_permuted_contractions");
//find contraction indices
int nc=0;
for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
if(rhs1.names[i]==rhs2.names[j])
{
//std::cout << "found contraction "<<nc<<" th. index = "<<rhs1.names[i]<<std::endl;
#ifdef LA_TENSOR_INDEXPOSITION
int group1=indexposition(i,rhs1.shape);
int group2=indexposition(j,rhs2.shape);
if(rhs1.shape[group1].upperindex ^ rhs2.shape[group2].upperindex == false) laerror("can contract only upper with lower index");
#endif
++nc;
}
INDEXLIST il1(nc),il2(nc);
int ii=0;
for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
if(rhs1.names[i]==rhs2.names[j])
{
il1[ii] = indexposition(i,rhs1.shape);
il2[ii] = indexposition(j,rhs2.shape);
++ii;
}
//std::cout<<"contraction list1 = "<<il1<<std::endl;
//std::cout<<"contraction list2 = "<<il2<<std::endl;
Tensor<T> tmp=rhs1.contractions(il1,rhs2,il2,alpha,conjugate1,conjugate2);
if(rank()!=tmp.rank()) laerror("rank mismatch in add_permuted_contractions");
//generate the antisymmetrizer, adding also indices not involved as a constant subpermutation
//@@@
PermutationAlgebra<int,T> pa(1);
pa[0].weight=1;
pa[0].perm.resize(rank());
pa[0].perm.identity();
//std::cout <<"LHS names = "<<names<<std::endl;
//std::cout <<"TMP names = "<<tmp.names<<std::endl;
//find permutation which will bring indices of tmp to the order as in *this: this->names[i] == tmp.names[p[i]]
NRPerm<int> basicperm(rank());
basicperm.clear();
for(int i=0; i<rank(); ++i)
{
for(int j=0; j<tmp.rank(); ++j)
if((* const_cast<const NRVec<INDEXNAME> *>(&names))[i] == tmp.names[j])
{
basicperm[i+1]=j+1;
break;
}
}
if(!basicperm.is_valid()) laerror("indices mismatch between lhs and rhs in add_permuted_contractions");
std::cout <<"Basic permutation = "<<basicperm<<std::endl;
PermutationAlgebra<int,T> pb = basicperm*pa;
apply_permutation_algebra(tmp,pb,false,(T)1/(T)pb.size(),beta);
//equivalently possible
//PermutationAlgebra<int,T> pb = basicperm.inverse()*pa;
//apply_permutation_algebra(tmp,pb,true,(T)1/(T)pb.size(),beta);
}