tensor: split and merge index groups

This commit is contained in:
Jiri Pittner 2024-05-06 18:30:01 +02:00
parent 161aa5b1cd
commit 883d201e67
2 changed files with 91 additions and 1 deletions

View File

@ -22,6 +22,7 @@
#include "qsort.h"
#include "miscfunc.h"
#include <complex>
#include "bitvector.h"
namespace LA {
@ -752,6 +753,93 @@ loopover(permutationalgebra_callback);
}
template<typename T>
void Tensor<T>::split_index_group(int group)
{
if(group<0||group >= shape.size()) laerror("illegal index group number");
if(shape[group].number==1) return; //nothing to split
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted");
NRVec<indexgroup> newshape(shape.size()+shape[group].number-1);
int gg=0;
for(int g=0; g<shape.size(); ++g)
{
if(g==group)
{
for(int i=0; i<shape[group].number; ++i)
{
newshape[gg] = shape[group];
newshape[gg].number = 1;
gg++;
}
}
else
newshape[gg++] = shape[g];
}
shape=newshape;
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
if(data.size()!=newsize) laerror("internal error in split_index_group");
}
template<typename T>
void Tensor<T>:: merge_adjacent_index_groups(int groupfrom, int groupto)
{
if(groupfrom<0||groupfrom>= shape.size()) laerror("illegal index group number");
if(groupto<0||groupto>= shape.size()) laerror("illegal index group number");
if(groupfrom==groupto) return;
if(groupfrom>groupto) {int t=groupfrom; groupfrom=groupto; groupto=t;}
int newnumber=0;
for(int g=groupfrom; g<=groupto; ++g)
{
if(shape[g].symmetry!=0) laerror("only non-symmetric index groups can be merged");
if(shape[g].offset!=shape[groupfrom].offset) laerror("incompatible offset in merge_adjacent_index_groups");
if(shape[g].range!=shape[groupfrom].range) laerror("incompatible range in merge_adjacent_index_groups");
newnumber += shape[g].number;
}
NRVec<indexgroup> newshape(shape.size()-(groupto-groupfrom+1)+1);
for(int g=0; g<=groupfrom; ++g) newshape[g]=shape[g];
newshape[groupfrom].number=newnumber;
for(int g=groupfrom+1; g<newshape.size(); ++g) newshape[g]=shape[g+groupto-groupfrom];
shape=newshape;
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
if(data.size()!=newsize) laerror("internal error in merge_adjacent_index_groups");
}
template<typename T>
Tensor<T> Tensor<T>::merge_index_groups(const NRVec<int> &groups) const
{
if(groups.size()<=1) return *this;
NRPerm<int> p(shape.size());
int gg=0;
bitvector selected(shape.size());
selected.clear();
for(int g=0; g<groups.size(); ++g)
{
if(groups[g]<0||groups[g]>=shape.size()) laerror("illegal group number in merge_index_groups");
if(selected[g]) laerror("repeated group number in merge_index_groups");
selected.set(g);
p[gg++] = 1+groups[g];
}
for(int g=0; g<shape.size(); ++g)
if(!selected[g])
p[gg++] = 1+g;
if(gg!=shape.size() || !p.is_valid()) laerror("internal error in merge_index_groups");
Tensor<T> r = permute_index_groups(p);
r.merge_adjacent_index_groups(0,groups.size()-1);
return r;
}
template class Tensor<double>;
template class Tensor<std::complex<double> >;
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);

View File

@ -190,7 +190,9 @@ public:
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
// PermutationAlgebra can represent e.g. general_antisymmetrizer in Kucharski-Bartlett notation
void split_index_group(int group); //formal split of a non-symmetric index group WITHOUT the need for data reorganization
void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization
Tensor merge_index_groups(const NRVec<int> &groups) const;
//TODO perhaps implement application of a permutation algebra to a product of several tensors
};