tensor: implemented merge_indices
This commit is contained in:
126
tensor.cc
126
tensor.cc
@@ -964,6 +964,7 @@ for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
if(il[i].group<0||il[i].group>=shape.size()) laerror("wrong group number in unwind_indices");
|
||||
if(il[i].index<0||il[i].index>=shape[il[i].group].number) laerror("wrong index number in unwind_indices");
|
||||
for(int j=0; j<i; ++j) if(il[i]==il[j]) laerror("repeated index in the list");
|
||||
}
|
||||
|
||||
//all indices are solo in their groups - permute groups
|
||||
@@ -1242,6 +1243,7 @@ for(int i=0; i<il1.size(); ++i)
|
||||
#ifdef LA_TENSOR_INDEXPOSITION
|
||||
if(rhs1.shape[il1[i].group].upperindex ^ rhs2.shape[il2[i].group].upperindex == false) laerror("can contact only upper with lower index");
|
||||
#endif
|
||||
for(int j=0; j<i; ++j) if(il1[i]==il1[j]||il2[i]==il2[j]) laerror("repeated index in the list");
|
||||
}
|
||||
|
||||
const Tensor<T> u = conjugate1? (rhs1.unwind_indices(il1)).conjugateme() : rhs1.unwind_indices(il1);
|
||||
@@ -1650,6 +1652,130 @@ return ind;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::merge_indices(const INDEXLIST &il, int sym) const
|
||||
{
|
||||
if(il.size()==0) laerror("empty index list for merge_indices");
|
||||
if(il.size()==1) return unwind_index(il[0]); //result should be index group of size 1
|
||||
|
||||
bool samegroup=true;
|
||||
bool isordered=true;
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
if(il[i].group<0||il[i].group>=shape.size()) laerror("wrong group number in merge_indices");
|
||||
if(il[i].index<0||il[i].index>=shape[il[i].group].number) laerror("wrong index number in merge_indices");
|
||||
for(int j=0; j<i; ++j) if(il[i]==il[j]) laerror("repeated index in the list");
|
||||
#ifdef LA_TENSOR_INDEXPOSITION
|
||||
if(shape[il[0].group].upperindex != shape[il[i].group].upperindex == false) laerror("can merge only within lower or upper separately");
|
||||
#endif
|
||||
if(shape[il[0].group].range != shape[il[i].group].range)
|
||||
{
|
||||
std::cout << "indices "<<il[0]<<" and "<<il[i]<< " have ranges "<<shape[il[0].group].range<< " and "<< shape[il[i].group].range <<" respectively\n";
|
||||
laerror("incompatible range in merge_indices");
|
||||
}
|
||||
if(shape[il[0].group].offset != shape[il[i].group].offset) laerror("incompatible offset in merge_indices");
|
||||
if(il[0].group != il[i].group) samegroup=false;
|
||||
if(il[i].index!=i) isordered=false;
|
||||
}
|
||||
|
||||
if(samegroup && isordered && il.size()==shape[il[0].group].number) return unwind_index_group(il[0].group);
|
||||
|
||||
|
||||
//calculate new shape and flat index permutation
|
||||
NRVec<indexgroup> workshape(shape);
|
||||
workshape.copyonwrite();
|
||||
NRPerm<int> basicperm(rank());
|
||||
|
||||
bitvector was_in_list(rank());
|
||||
was_in_list.clear();
|
||||
for(int i=0; i<il.size(); ++i)
|
||||
{
|
||||
int fp=flatposition(il[i],shape);
|
||||
was_in_list.set(fp);
|
||||
basicperm[i+1] = 1+fp;
|
||||
if( --workshape[il[i].group].number <0) laerror("inconsistent index list with index group size");
|
||||
}
|
||||
int newshapesize=1; //newly created group
|
||||
for(int i=0; i<workshape.size(); ++i) if(workshape[i].number>0) ++newshapesize; //this group survived index removal
|
||||
|
||||
NRVec<indexgroup> newshape(newshapesize);
|
||||
newshape[0].number=il.size();
|
||||
newshape[0].symmetry=sym;
|
||||
newshape[0].offset=shape[il[0].group].offset;
|
||||
newshape[0].range=shape[il[0].group].range;
|
||||
#ifdef LA_TENSOR_INDEXPOSITION
|
||||
newshape[0].upperindex=shape[il[0].group].upperindex;
|
||||
#endif
|
||||
int ii=1;
|
||||
for(int i=0; i<workshape.size(); ++i)
|
||||
if(workshape[i].number>0)
|
||||
newshape[ii++] = workshape[i];
|
||||
int jj=1+il.size();
|
||||
for(int i=0; i<rank(); ++i)
|
||||
if(!was_in_list[i])
|
||||
basicperm[jj++] = 1+i;
|
||||
if(!basicperm.is_valid()) laerror("internal error in merge_indices");
|
||||
|
||||
//std::cout <<"newshape = "<<newshape<<std::endl;
|
||||
//std::cout <<"basicperm = "<<basicperm<<std::endl;
|
||||
|
||||
|
||||
//prepare permutation algebra
|
||||
PermutationAlgebra<int,T> pa;
|
||||
if(sym==0)
|
||||
{
|
||||
pa.resize(1);
|
||||
pa[0].weight=1;
|
||||
pa[0].perm=basicperm;
|
||||
}
|
||||
else
|
||||
{
|
||||
PermutationAlgebra<int,int> sa = sym>0 ? symmetrizer<int>(il.size()) : antisymmetrizer<int>(il.size());
|
||||
//std::cout <<"SA = "<<sa<<std::endl;
|
||||
pa.resize(sa.size());
|
||||
for(int i=0; i<sa.size(); ++i)
|
||||
{
|
||||
pa[i].weight = (T) sa[i].weight;
|
||||
pa[i].perm.resize(rank());
|
||||
for(int j=1; j<=il.size(); ++j) pa[i].perm[j] = basicperm[sa[i].perm[j]];
|
||||
for(int j=il.size()+1; j<=rank(); ++j) pa[i].perm[j] = basicperm[j];
|
||||
}
|
||||
}
|
||||
|
||||
//std::cout <<"Use PA = "<<pa<<std::endl;
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
r.apply_permutation_algebra(*this,pa,false,(T)1/(T)pa.size(),0);
|
||||
return r;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::canonicalize_shape()
|
||||
{
|
||||
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[0];
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
if(sh[i].number==1 && sh[i].symmetry!=0) {shape.copyonwrite(); shape[i].symmetry=0;}
|
||||
if(sh[i].symmetry>1 ) {shape.copyonwrite(); shape[i].symmetry=1;}
|
||||
if(sh[i].symmetry<-1) {shape.copyonwrite(); shape[i].symmetry= -1;}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::ostream & operator<<(std::ostream &s, const INDEX &x)
|
||||
{
|
||||
s<<x.group<<" "<<x.index;
|
||||
return s;
|
||||
}
|
||||
|
||||
std::istream & operator>>(std::istream &s, INDEX &x)
|
||||
{
|
||||
s>>x.group>>x.index;
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
|
||||
Reference in New Issue
Block a user