working on add_permuted_contractions
This commit is contained in:
87
tensor.cc
87
tensor.cc
@@ -85,7 +85,11 @@ LA_largeindex subindex(int *sign, const INDEXGROUP &g, const NRVec<LA_index> &I)
|
||||
#ifdef DEBUG
|
||||
if(I.size()<=0) laerror("empty index group in subindex");
|
||||
if(g.number!=I.size()) laerror("mismatch in the number of indices in a group");
|
||||
for(int i=0; i<I.size(); ++i) if(I[i]<g.offset || I[i] >= g.offset+g.range) laerror("index out of range in tensor subindex");
|
||||
for(int i=0; i<I.size(); ++i) if(I[i]<g.offset || I[i] >= g.offset+g.range)
|
||||
{
|
||||
std::cout<<"TENSOR INDEX PROBLEM in group " <<g<<" with index "<<I<<std::endl;
|
||||
laerror("index out of range in tensor subindex");
|
||||
}
|
||||
#endif
|
||||
|
||||
switch(I.size()) //a few special cases for efficiency
|
||||
@@ -232,7 +236,7 @@ for(int i=0; i<I.size(); ++i)
|
||||
{
|
||||
if(I[i][j] <shape[i].offset || I[i][j] >= shape[i].offset+shape[i].range)
|
||||
{
|
||||
std::cerr<<"error in index group no. "<<i<<" index no. "<<j<<std::endl;
|
||||
std::cout<<"TENSOR INDEX PROBLEM group no. "<<i<<" index no. "<<j<<" should be between "<<shape[i].offset<<" and "<<shape[i].offset+shape[i].range-1<<std::endl;
|
||||
laerror("tensor index out of range");
|
||||
}
|
||||
}
|
||||
@@ -1883,6 +1887,8 @@ if(is_named()) r.names=names.permuted(basicperm,true);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::canonicalize_shape()
|
||||
{
|
||||
@@ -1999,7 +2005,7 @@ foundname:
|
||||
finished:
|
||||
|
||||
names0= NRVec<INDEXNAME> (names);
|
||||
std::cout <<"names parsed "<<names0;
|
||||
//std::cout <<"names parsed "<<names0;
|
||||
|
||||
groups0.resize(groups.size());
|
||||
int i=0;
|
||||
@@ -2007,7 +2013,7 @@ for(typename std::list<std::list<int> >::const_iterator ii=groups.begin(); ii!=g
|
||||
{
|
||||
groups0[i++] = NRVec_from1<int>(*ii);
|
||||
}
|
||||
std::cout<<"groups parsed "<<groups0;
|
||||
//std::cout<<"groups parsed "<<groups0;
|
||||
|
||||
free(txt);
|
||||
}
|
||||
@@ -2033,10 +2039,14 @@ for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
|
||||
}
|
||||
|
||||
INDEXLIST il1(nc),il2(nc);
|
||||
bitvector is_c_index1(rhs1.names.size());
|
||||
bitvector is_c_index2(rhs2.names.size());
|
||||
int ii=0;
|
||||
for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
|
||||
for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs2.names.size(); ++j)
|
||||
if(rhs1.names[i]==rhs2.names[j])
|
||||
{
|
||||
is_c_index1.set(i);
|
||||
is_c_index2.set(j);
|
||||
il1[ii] = indexposition(i,rhs1.shape);
|
||||
il2[ii] = indexposition(j,rhs2.shape);
|
||||
++ii;
|
||||
@@ -2045,8 +2055,11 @@ for(int i=0; i<rhs1.names.size(); ++i) for(int j=0; j<rhs1.names.size(); ++j)
|
||||
//std::cout<<"contraction list1 = "<<il1<<std::endl;
|
||||
//std::cout<<"contraction list2 = "<<il2<<std::endl;
|
||||
|
||||
Tensor<T> tmp=rhs1.contractions(il1,rhs2,il2,alpha,conjugate1,conjugate2);
|
||||
if(rank()!=tmp.rank()) laerror("rank mismatch in add_permuted_contractions");
|
||||
//make a dry-run to get only the list of names, concatenate names of both tensors unless they are the contraction ones
|
||||
std::list<INDEXNAME> listnames;
|
||||
for(int j=0; j<rhs2.names.size(); ++j) if(!is_c_index2[j]) listnames.push_back(rhs2.names[j]);
|
||||
for(int j=0; j<rhs1.names.size(); ++j) if(!is_c_index1[j]) listnames.push_back(rhs1.names[j]);
|
||||
NRVec<INDEXNAME> tmpnames(listnames);
|
||||
|
||||
//generate the antisymmetrizer, adding also indices not involved as a constant subpermutation
|
||||
//
|
||||
@@ -2056,14 +2069,14 @@ NRVec<NRVec_from1<int> > antigroups;
|
||||
parse_antisymmetrizer(antisymmetrizer,antigroups,antinames);
|
||||
|
||||
//check the names make sense and fill in the possibly missing ones as separate group
|
||||
if(antinames.size()>tmp.names.size()) laerror("too many indices in the antisymmetrizet");
|
||||
bitvector isexplicit(tmp.names.size());
|
||||
if(antinames.size()>tmpnames.size()) laerror("too many indices in the antisymmetrizet");
|
||||
bitvector isexplicit(tmpnames.size());
|
||||
isexplicit.clear();
|
||||
for(int i=0; i<antinames.size(); ++i)
|
||||
{
|
||||
for(int j=0; j<i; ++j) if(antinames[i]==antinames[j]) laerror("repeated names in the antisymmetrizer");
|
||||
for(int j=0; j<tmp.names.size(); ++j)
|
||||
if(antinames[i]==tmp.names[j])
|
||||
for(int j=0; j<tmpnames.size(); ++j)
|
||||
if(antinames[i]==tmpnames[j])
|
||||
{
|
||||
isexplicit.set(j);
|
||||
goto namefound;
|
||||
@@ -2074,15 +2087,15 @@ for(int i=0; i<antinames.size(); ++i)
|
||||
if(isexplicit.population()!=antinames.size()) laerror("internal error in add_permuted_contractions");
|
||||
|
||||
//fill in additional names
|
||||
if(antinames.size()<tmp.names.size())
|
||||
if(antinames.size()<tmpnames.size())
|
||||
{
|
||||
int lastgroup=antigroups.size()-1;
|
||||
int lastclass;
|
||||
if(antigroups.size()==0) lastclass=0;
|
||||
else lastclass=antigroups[antigroups.size()-1][antigroups[antigroups.size()-1].size()];
|
||||
int lastname=antinames.size()-1;
|
||||
antinames.resize(tmp.names.size(),true);
|
||||
antigroups.resize(antigroups.size()+tmp.names.size()-antinames.size(),true);
|
||||
antigroups.resize(antigroups.size()+tmpnames.size()-antinames.size(),true);
|
||||
antinames.resize(tmpnames.size(),true);
|
||||
for(int j=0; j<names.size(); ++j)
|
||||
if(!isexplicit[j])
|
||||
{
|
||||
@@ -2090,37 +2103,49 @@ if(antinames.size()<tmp.names.size())
|
||||
++lastname;
|
||||
++lastgroup;
|
||||
antigroups[lastgroup].resize(1);
|
||||
antigroups[lastgroup][0]=lastclass;
|
||||
antinames[lastname] = tmp.names[j];
|
||||
antigroups[lastgroup][1]=lastclass;
|
||||
antinames[lastname] = tmpnames[j];
|
||||
}
|
||||
}
|
||||
std::cout <<"final antisymmmetrizer names and groups"<<antinames<<antigroups;
|
||||
//std::cout <<"LHS names = "<<names<<std::endl;
|
||||
//std::cout <<"TMP names = "<<tmpnames<<std::endl;
|
||||
|
||||
//prepare the antisymmetrizer
|
||||
PermutationAlgebra<int,int> pa = general_antisymmetrizer(antigroups,-2,true);
|
||||
std::cout <<"initial antisymmetrizer = "<<pa;
|
||||
|
||||
//find permutation between antisym and TMP index order
|
||||
NRPerm<int> antiperm=find_indexperm(antinames,tmp.names);
|
||||
NRPerm<int> antiperm=find_indexperm(antinames,tmpnames);
|
||||
std::cout<<"permutation from rhs to antisymmetrizer = "<<antiperm<<std::endl;
|
||||
|
||||
//@@@conjugate the PA by antiperm or its inverse
|
||||
//conjugate the PA by antiperm or its inverse
|
||||
pa= pa.conjugated_by(antiperm,true);//@@@ or false?
|
||||
|
||||
//@@@recast the PA
|
||||
//find permutation which will bring indices of tmp to the order as in *this: this->names[i] == tmpnames[p[i]]
|
||||
NRPerm<int> basicperm=find_indexperm(names,tmpnames);
|
||||
std::cout <<"permutation from rhs to lhs = "<<basicperm<<std::endl;
|
||||
|
||||
//@@@permutationalgebra::is_identity() a pokd ano, neaplikovat ale vratit tmp tenzor resp. s nim udelat axpy
|
||||
//@@@PermutationAlgebra<int,T> pb = basicperm*pa; //for apply_permutation_algebra(tmp,pb,false,(T)1,beta);
|
||||
PermutationAlgebra<int,T> pb = basicperm.inverse()*pa;
|
||||
//@@@ OR PermutationAlgebra<int,T> pb = basicperm.inverse()*pa; apply_permutation_algebra(tmp,pb,true,(T)1,beta);
|
||||
|
||||
//std::cout <<"LHS names = "<<names<<std::endl;
|
||||
//std::cout <<"TMP names = "<<tmp.names<<std::endl;
|
||||
//save some work if the PA is trivial
|
||||
if(pb.is_identity())
|
||||
{
|
||||
std::cout <<"simplified version\n";
|
||||
addcontractions(rhs1,il1,rhs2,il2,alpha,beta,false,conjugate1,conjugate2);
|
||||
return;
|
||||
}
|
||||
|
||||
//find permutation which will bring indices of tmp to the order as in *this: this->names[i] == tmp.names[p[i]]
|
||||
NRPerm<int> basicperm=find_indexperm(names,tmp.names);
|
||||
std::cout <<"Basic permutation = "<<basicperm<<std::endl;
|
||||
std::cout <<"full version\n";
|
||||
//create an intermediate contracted tensir and the permute it
|
||||
Tensor<T> tmp=rhs1.contractions(il1,rhs2,il2,alpha,conjugate1,conjugate2);
|
||||
if(rank()!=tmp.rank()) laerror("rank mismatch in add_permuted_contractions");
|
||||
if(tmp.names!=tmpnames) laerror("internal error in add_permuted_contractions");
|
||||
|
||||
//@@@probably only onw of those will work
|
||||
PermutationAlgebra<int,T> pb = basicperm*pa;
|
||||
apply_permutation_algebra(tmp,pb,false,(T)1/(T)pb.size(),beta);
|
||||
//equivalently possible (for trivial pa)
|
||||
//PermutationAlgebra<int,T> pb = basicperm.inverse()*pa;
|
||||
//apply_permutation_algebra(tmp,pb,true,(T)1/(T)pb.size(),beta);
|
||||
//@@@apply_permutation_algebra(tmp,pb,false,(T)1,beta);
|
||||
apply_permutation_algebra(tmp,pb,true,(T)1,beta);
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user