working on add_permuted_contractions

This commit is contained in:
2025-11-16 14:56:51 +01:00
parent ba5adcd5e6
commit 71c890c39b
3 changed files with 70 additions and 39 deletions

2
t.cc
View File

@@ -4254,7 +4254,7 @@ for(int k=0; k<nn; ++k) for(int l=0; l<nn; ++l) for(int m=0; m<nn; ++m) for(int
Tensor<double> zzz(s); Tensor<double> zzz(s);
for(int k=0; k<nn; ++k) for(int l=0; l<nn; ++l) for(int m=0; m<nn; ++m) for(int n=0; n<nn; ++n) for(int k=0; k<nn; ++k) for(int l=0; l<nn; ++l) for(int m=0; m<nn; ++m) for(int n=0; n<nn; ++n)
{ {
zzz.lhs(k,l,m,n) = 1/3.*(zz(k,l,m,n)-zz(l,k,m,n)+zz(m,k,l,n)); zzz.lhs(k,l,m,n) = (zz(k,l,m,n)-zz(l,k,m,n)+zz(m,k,l,n));
} }
cout <<"Error = "<<(z-zzz).norm()<<endl; cout <<"Error = "<<(z-zzz).norm()<<endl;

View File

@@ -85,7 +85,11 @@ LA_largeindex subindex(int *sign, const INDEXGROUP &g, const NRVec<LA_index> &I)
#ifdef DEBUG #ifdef DEBUG
if(I.size()<=0) laerror("empty index group in subindex"); if(I.size()<=0) laerror("empty index group in subindex");
if(g.number!=I.size()) laerror("mismatch in the number of indices in a group"); if(g.number!=I.size()) laerror("mismatch in the number of indices in a group");
for(int i=0; i<I.size(); ++i) if(I[i]<g.offset || I[i] >= g.offset+g.range) laerror("index out of range in tensor subindex"); for(int i=0; i<I.size(); ++i) if(I[i]<g.offset || I[i] >= g.offset+g.range)
{
std::cout<<"TENSOR INDEX PROBLEM in group " <<g<<" with index "<<I<<std::endl;
laerror("index out of range in tensor subindex");
}
#endif #endif
switch(I.size()) //a few special cases for efficiency switch(I.size()) //a few special cases for efficiency
@@ -232,7 +236,7 @@ for(int i=0; i<I.size(); ++i)
{ {
if(I[i][j] <shape[i].offset || I[i][j] >= shape[i].offset+shape[i].range) if(I[i][j] <shape[i].offset || I[i][j] >= shape[i].offset+shape[i].range)
{ {
std::cerr<<"error in index group no. "<<i<<" index no. "<<j<<std::endl; std::cout<<"TENSOR INDEX PROBLEM group no. "<<i<<" index no. "<<j<<" should be between "<<shape[i].offset<<" and "<<shape[i].offset+shape[i].range-1<<std::endl;
laerror("tensor index out of range"); laerror("tensor index out of range");
} }
} }
@@ -1883,6 +1887,8 @@ if(is_named()) r.names=names.permuted(basicperm,true);
return r; return r;
} }
template<typename T> template<typename T>
void Tensor<T>::canonicalize_shape() void Tensor<T>::canonicalize_shape()
{ {
@@ -1999,7 +2005,7 @@ foundname:
finished: finished:
names0= NRVec<INDEXNAME> (names); names0= NRVec<INDEXNAME> (names);
std::cout <<"names parsed "<<names0; //std::cout <<"names parsed "<<names0;
groups0.resize(groups.size()); groups0.resize(groups.size());
int i=0; int i=0;
@@ -2007,7 +2013,7 @@ for(typename std::list<std::list<int> >::const_iterator ii=groups.begin(); ii!=g
{ {
groups0[i++] = NRVec_from1<int>(*ii); groups0[i++] = NRVec_from1<int>(*ii);
} }
std::cout<<"groups parsed "<<groups0; //std::cout<<"groups parsed "<<groups0;
free(txt); free(txt);
} }
@@ -2033,10 +2039,14 @@ for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
} }
INDEXLIST il1(nc),il2(nc); INDEXLIST il1(nc),il2(nc);
bitvector is_c_index1(rhs1.names.size());
bitvector is_c_index2(rhs2.names.size());
int ii=0; int ii=0;
for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j) for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs2.names.size(); ++j)
if(rhs1.names[i]==rhs2.names[j]) if(rhs1.names[i]==rhs2.names[j])
{ {
is_c_index1.set(i);
is_c_index2.set(j);
il1[ii] = indexposition(i,rhs1.shape); il1[ii] = indexposition(i,rhs1.shape);
il2[ii] = indexposition(j,rhs2.shape); il2[ii] = indexposition(j,rhs2.shape);
++ii; ++ii;
@@ -2045,8 +2055,11 @@ for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
//std::cout<<"contraction list1 = "<<il1<<std::endl; //std::cout<<"contraction list1 = "<<il1<<std::endl;
//std::cout<<"contraction list2 = "<<il2<<std::endl; //std::cout<<"contraction list2 = "<<il2<<std::endl;
Tensor<T> tmp=rhs1.contractions(il1,rhs2,il2,alpha,conjugate1,conjugate2); //make a dry-run to get only the list of names, concatenate names of both tensors unless they are the contraction ones
if(rank()!=tmp.rank()) laerror("rank mismatch in add_permuted_contractions"); std::list<INDEXNAME> listnames;
for(int j=0; j<rhs2.names.size(); ++j) if(!is_c_index2[j]) listnames.push_back(rhs2.names[j]);
for(int j=0; j<rhs1.names.size(); ++j) if(!is_c_index1[j]) listnames.push_back(rhs1.names[j]);
NRVec<INDEXNAME> tmpnames(listnames);
//generate the antisymmetrizer, adding also indices not involved as a constant subpermutation //generate the antisymmetrizer, adding also indices not involved as a constant subpermutation
// //
@@ -2056,14 +2069,14 @@ NRVec<NRVec_from1<int> > antigroups;
parse_antisymmetrizer(antisymmetrizer,antigroups,antinames); parse_antisymmetrizer(antisymmetrizer,antigroups,antinames);
//check the names make sense and fill in the possibly missing ones as separate group //check the names make sense and fill in the possibly missing ones as separate group
if(antinames.size()>tmp.names.size()) laerror("too many indices in the antisymmetrizet"); if(antinames.size()>tmpnames.size()) laerror("too many indices in the antisymmetrizet");
bitvector isexplicit(tmp.names.size()); bitvector isexplicit(tmpnames.size());
isexplicit.clear(); isexplicit.clear();
for(int i=0; i<antinames.size(); ++i) for(int i=0; i<antinames.size(); ++i)
{ {
for(int j=0; j<i; ++j) if(antinames[i]==antinames[j]) laerror("repeated names in the antisymmetrizer"); for(int j=0; j<i; ++j) if(antinames[i]==antinames[j]) laerror("repeated names in the antisymmetrizer");
for(int j=0; j<tmp.names.size(); ++j) for(int j=0; j<tmpnames.size(); ++j)
if(antinames[i]==tmp.names[j]) if(antinames[i]==tmpnames[j])
{ {
isexplicit.set(j); isexplicit.set(j);
goto namefound; goto namefound;
@@ -2074,15 +2087,15 @@ for(int i=0; i<antinames.size(); ++i)
if(isexplicit.population()!=antinames.size()) laerror("internal error in add_permuted_contractions"); if(isexplicit.population()!=antinames.size()) laerror("internal error in add_permuted_contractions");
//fill in additional names //fill in additional names
if(antinames.size()<tmp.names.size()) if(antinames.size()<tmpnames.size())
{ {
int lastgroup=antigroups.size()-1; int lastgroup=antigroups.size()-1;
int lastclass; int lastclass;
if(antigroups.size()==0) lastclass=0; if(antigroups.size()==0) lastclass=0;
else lastclass=antigroups[antigroups.size()-1][antigroups[antigroups.size()-1].size()]; else lastclass=antigroups[antigroups.size()-1][antigroups[antigroups.size()-1].size()];
int lastname=antinames.size()-1; int lastname=antinames.size()-1;
antinames.resize(tmp.names.size(),true); antigroups.resize(antigroups.size()+tmpnames.size()-antinames.size(),true);
antigroups.resize(antigroups.size()+tmp.names.size()-antinames.size(),true); antinames.resize(tmpnames.size(),true);
for(int j=0; j<names.size(); ++j) for(int j=0; j<names.size(); ++j)
if(!isexplicit[j]) if(!isexplicit[j])
{ {
@@ -2090,37 +2103,49 @@ if(antinames.size()<tmp.names.size())
++lastname; ++lastname;
++lastgroup; ++lastgroup;
antigroups[lastgroup].resize(1); antigroups[lastgroup].resize(1);
antigroups[lastgroup][0]=lastclass; antigroups[lastgroup][1]=lastclass;
antinames[lastname] = tmp.names[j]; antinames[lastname] = tmpnames[j];
} }
} }
std::cout <<"final antisymmmetrizer names and groups"<<antinames<<antigroups; std::cout <<"final antisymmmetrizer names and groups"<<antinames<<antigroups;
//std::cout <<"LHS names = "<<names<<std::endl;
//std::cout <<"TMP names = "<<tmpnames<<std::endl;
//prepare the antisymmetrizer //prepare the antisymmetrizer
PermutationAlgebra<int,int> pa = general_antisymmetrizer(antigroups,-2,true); PermutationAlgebra<int,int> pa = general_antisymmetrizer(antigroups,-2,true);
std::cout <<"initial antisymmetrizer = "<<pa;
//find permutation between antisym and TMP index order //find permutation between antisym and TMP index order
NRPerm<int> antiperm=find_indexperm(antinames,tmp.names); NRPerm<int> antiperm=find_indexperm(antinames,tmpnames);
std::cout<<"permutation from rhs to antisymmetrizer = "<<antiperm<<std::endl;
//@@@conjugate the PA by antiperm or its inverse //conjugate the PA by antiperm or its inverse
pa= pa.conjugated_by(antiperm,true);//@@@ or false?
//@@@recast the PA //find permutation which will bring indices of tmp to the order as in *this: this->names[i] == tmpnames[p[i]]
NRPerm<int> basicperm=find_indexperm(names,tmpnames);
std::cout <<"permutation from rhs to lhs = "<<basicperm<<std::endl;
//@@@permutationalgebra::is_identity() a pokd ano, neaplikovat ale vratit tmp tenzor resp. s nim udelat axpy //@@@PermutationAlgebra<int,T> pb = basicperm*pa; //for apply_permutation_algebra(tmp,pb,false,(T)1,beta);
PermutationAlgebra<int,T> pb = basicperm.inverse()*pa;
//@@@ OR PermutationAlgebra<int,T> pb = basicperm.inverse()*pa; apply_permutation_algebra(tmp,pb,true,(T)1,beta);
//std::cout <<"LHS names = "<<names<<std::endl; //save some work if the PA is trivial
//std::cout <<"TMP names = "<<tmp.names<<std::endl; if(pb.is_identity())
{
std::cout <<"simplified version\n";
addcontractions(rhs1,il1,rhs2,il2,alpha,beta,false,conjugate1,conjugate2);
return;
}
//find permutation which will bring indices of tmp to the order as in *this: this->names[i] == tmp.names[p[i]] std::cout <<"full version\n";
NRPerm<int> basicperm=find_indexperm(names,tmp.names); //create an intermediate contracted tensir and the permute it
std::cout <<"Basic permutation = "<<basicperm<<std::endl; Tensor<T> tmp=rhs1.contractions(il1,rhs2,il2,alpha,conjugate1,conjugate2);
if(rank()!=tmp.rank()) laerror("rank mismatch in add_permuted_contractions");
if(tmp.names!=tmpnames) laerror("internal error in add_permuted_contractions");
//@@@probably only onw of those will work //@@@apply_permutation_algebra(tmp,pb,false,(T)1,beta);
PermutationAlgebra<int,T> pb = basicperm*pa; apply_permutation_algebra(tmp,pb,true,(T)1,beta);
apply_permutation_algebra(tmp,pb,false,(T)1/(T)pb.size(),beta);
//equivalently possible (for trivial pa)
//PermutationAlgebra<int,T> pb = basicperm.inverse()*pa;
//apply_permutation_algebra(tmp,pb,true,(T)1/(T)pb.size(),beta);
} }

View File

@@ -267,23 +267,29 @@ public:
NRVec<INDEX> findindexlist(const NRVec<INDEXNAME> &names) const; NRVec<INDEX> findindexlist(const NRVec<INDEXNAME> &names) const;
void renameindex(const INDEXNAME namfrom, const INDEXNAME nameto) {int i=findflatindex(namfrom); names[i]=nameto;}; void renameindex(const INDEXNAME namfrom, const INDEXNAME nameto) {int i=findflatindex(namfrom); names[i]=nameto;};
inline Tensor& operator+=(const Tensor &rhs) Tensor& operator+=(const Tensor &rhs)
{ {
#ifdef DEBUG
if(shape!=rhs.shape) laerror("incompatible tensors for operation"); if(shape!=rhs.shape) laerror("incompatible tensors for operation");
#endif if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible names for operation");
data+=rhs.data; data+=rhs.data;
return *this; return *this;
} }
inline Tensor& operator-=(const Tensor &rhs) Tensor& operator-=(const Tensor &rhs)
{ {
#ifdef DEBUG
if(shape!=rhs.shape) laerror("incompatible tensors for operation"); if(shape!=rhs.shape) laerror("incompatible tensors for operation");
#endif if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible names for operation");
data-=rhs.data; data-=rhs.data;
return *this; return *this;
} }
Tensor& axpy(const T alpha, const Tensor &rhs)
{
if(shape!=rhs.shape) laerror("incompatible tensors for operation");
if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible names for operation");
data.axpy(alpha,rhs.data);
return *this;
}
inline Tensor operator+(const Tensor &rhs) const {Tensor r(*this); r+=rhs; return r;}; inline Tensor operator+(const Tensor &rhs) const {Tensor r(*this); r+=rhs; return r;};
inline Tensor operator-(const Tensor &rhs) const {Tensor r(*this); r-=rhs; return r;}; inline Tensor operator-(const Tensor &rhs) const {Tensor r(*this); r-=rhs; return r;};
@@ -613,7 +619,7 @@ NRMat<T> mat(t.data,range*(range-1)/2,range*(range-1)/2);
f=fourindex_dense<antisymtwoelectronrealdirac,T,I>(range,NRSMat<T>(mat)); //symmetrize mat f=fourindex_dense<antisymtwoelectronrealdirac,T,I>(range,NRSMat<T>(mat)); //symmetrize mat
} }
//@@@formal permutation of names inside a sym/antisy group (with possible sign change)
template <typename T> template <typename T>