tensor: split and merge index groups
This commit is contained in:
parent
161aa5b1cd
commit
883d201e67
88
tensor.cc
88
tensor.cc
@ -22,6 +22,7 @@
|
||||
#include "qsort.h"
|
||||
#include "miscfunc.h"
|
||||
#include <complex>
|
||||
#include "bitvector.h"
|
||||
|
||||
|
||||
namespace LA {
|
||||
@ -752,6 +753,93 @@ loopover(permutationalgebra_callback);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>::split_index_group(int group)
|
||||
{
|
||||
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||
if(shape[group].number==1) return; //nothing to split
|
||||
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted");
|
||||
|
||||
NRVec<indexgroup> newshape(shape.size()+shape[group].number-1);
|
||||
int gg=0;
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
{
|
||||
if(g==group)
|
||||
{
|
||||
for(int i=0; i<shape[group].number; ++i)
|
||||
{
|
||||
newshape[gg] = shape[group];
|
||||
newshape[gg].number = 1;
|
||||
gg++;
|
||||
}
|
||||
}
|
||||
else
|
||||
newshape[gg++] = shape[g];
|
||||
}
|
||||
|
||||
shape=newshape;
|
||||
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
|
||||
if(data.size()!=newsize) laerror("internal error in split_index_group");
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Tensor<T>:: merge_adjacent_index_groups(int groupfrom, int groupto)
|
||||
{
|
||||
if(groupfrom<0||groupfrom>= shape.size()) laerror("illegal index group number");
|
||||
if(groupto<0||groupto>= shape.size()) laerror("illegal index group number");
|
||||
if(groupfrom==groupto) return;
|
||||
if(groupfrom>groupto) {int t=groupfrom; groupfrom=groupto; groupto=t;}
|
||||
|
||||
int newnumber=0;
|
||||
for(int g=groupfrom; g<=groupto; ++g)
|
||||
{
|
||||
if(shape[g].symmetry!=0) laerror("only non-symmetric index groups can be merged");
|
||||
if(shape[g].offset!=shape[groupfrom].offset) laerror("incompatible offset in merge_adjacent_index_groups");
|
||||
if(shape[g].range!=shape[groupfrom].range) laerror("incompatible range in merge_adjacent_index_groups");
|
||||
newnumber += shape[g].number;
|
||||
}
|
||||
|
||||
NRVec<indexgroup> newshape(shape.size()-(groupto-groupfrom+1)+1);
|
||||
for(int g=0; g<=groupfrom; ++g) newshape[g]=shape[g];
|
||||
newshape[groupfrom].number=newnumber;
|
||||
for(int g=groupfrom+1; g<newshape.size(); ++g) newshape[g]=shape[g+groupto-groupfrom];
|
||||
|
||||
shape=newshape;
|
||||
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
|
||||
if(data.size()!=newsize) laerror("internal error in merge_adjacent_index_groups");
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::merge_index_groups(const NRVec<int> &groups) const
|
||||
{
|
||||
if(groups.size()<=1) return *this;
|
||||
NRPerm<int> p(shape.size());
|
||||
int gg=0;
|
||||
bitvector selected(shape.size());
|
||||
selected.clear();
|
||||
for(int g=0; g<groups.size(); ++g)
|
||||
{
|
||||
if(groups[g]<0||groups[g]>=shape.size()) laerror("illegal group number in merge_index_groups");
|
||||
if(selected[g]) laerror("repeated group number in merge_index_groups");
|
||||
selected.set(g);
|
||||
p[gg++] = 1+groups[g];
|
||||
}
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
if(!selected[g])
|
||||
p[gg++] = 1+g;
|
||||
if(gg!=shape.size() || !p.is_valid()) laerror("internal error in merge_index_groups");
|
||||
|
||||
Tensor<T> r = permute_index_groups(p);
|
||||
r.merge_adjacent_index_groups(0,groups.size()-1);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);
|
||||
|
4
tensor.h
4
tensor.h
@ -190,7 +190,9 @@ public:
|
||||
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
||||
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
||||
// PermutationAlgebra can represent e.g. general_antisymmetrizer in Kucharski-Bartlett notation
|
||||
|
||||
void split_index_group(int group); //formal split of a non-symmetric index group WITHOUT the need for data reorganization
|
||||
void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization
|
||||
Tensor merge_index_groups(const NRVec<int> &groups) const;
|
||||
|
||||
//TODO perhaps implement application of a permutation algebra to a product of several tensors
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user