tensor: split and merge index groups
This commit is contained in:
parent
161aa5b1cd
commit
883d201e67
88
tensor.cc
88
tensor.cc
@ -22,6 +22,7 @@
|
|||||||
#include "qsort.h"
|
#include "qsort.h"
|
||||||
#include "miscfunc.h"
|
#include "miscfunc.h"
|
||||||
#include <complex>
|
#include <complex>
|
||||||
|
#include "bitvector.h"
|
||||||
|
|
||||||
|
|
||||||
namespace LA {
|
namespace LA {
|
||||||
@ -752,6 +753,93 @@ loopover(permutationalgebra_callback);
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::split_index_group(int group)
|
||||||
|
{
|
||||||
|
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||||
|
if(shape[group].number==1) return; //nothing to split
|
||||||
|
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted");
|
||||||
|
|
||||||
|
NRVec<indexgroup> newshape(shape.size()+shape[group].number-1);
|
||||||
|
int gg=0;
|
||||||
|
for(int g=0; g<shape.size(); ++g)
|
||||||
|
{
|
||||||
|
if(g==group)
|
||||||
|
{
|
||||||
|
for(int i=0; i<shape[group].number; ++i)
|
||||||
|
{
|
||||||
|
newshape[gg] = shape[group];
|
||||||
|
newshape[gg].number = 1;
|
||||||
|
gg++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
newshape[gg++] = shape[g];
|
||||||
|
}
|
||||||
|
|
||||||
|
shape=newshape;
|
||||||
|
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
|
||||||
|
if(data.size()!=newsize) laerror("internal error in split_index_group");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>:: merge_adjacent_index_groups(int groupfrom, int groupto)
|
||||||
|
{
|
||||||
|
if(groupfrom<0||groupfrom>= shape.size()) laerror("illegal index group number");
|
||||||
|
if(groupto<0||groupto>= shape.size()) laerror("illegal index group number");
|
||||||
|
if(groupfrom==groupto) return;
|
||||||
|
if(groupfrom>groupto) {int t=groupfrom; groupfrom=groupto; groupto=t;}
|
||||||
|
|
||||||
|
int newnumber=0;
|
||||||
|
for(int g=groupfrom; g<=groupto; ++g)
|
||||||
|
{
|
||||||
|
if(shape[g].symmetry!=0) laerror("only non-symmetric index groups can be merged");
|
||||||
|
if(shape[g].offset!=shape[groupfrom].offset) laerror("incompatible offset in merge_adjacent_index_groups");
|
||||||
|
if(shape[g].range!=shape[groupfrom].range) laerror("incompatible range in merge_adjacent_index_groups");
|
||||||
|
newnumber += shape[g].number;
|
||||||
|
}
|
||||||
|
|
||||||
|
NRVec<indexgroup> newshape(shape.size()-(groupto-groupfrom+1)+1);
|
||||||
|
for(int g=0; g<=groupfrom; ++g) newshape[g]=shape[g];
|
||||||
|
newshape[groupfrom].number=newnumber;
|
||||||
|
for(int g=groupfrom+1; g<newshape.size(); ++g) newshape[g]=shape[g+groupto-groupfrom];
|
||||||
|
|
||||||
|
shape=newshape;
|
||||||
|
LA_largeindex newsize = calcsize(); //recalculate auxiliary arrays
|
||||||
|
if(data.size()!=newsize) laerror("internal error in merge_adjacent_index_groups");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
Tensor<T> Tensor<T>::merge_index_groups(const NRVec<int> &groups) const
|
||||||
|
{
|
||||||
|
if(groups.size()<=1) return *this;
|
||||||
|
NRPerm<int> p(shape.size());
|
||||||
|
int gg=0;
|
||||||
|
bitvector selected(shape.size());
|
||||||
|
selected.clear();
|
||||||
|
for(int g=0; g<groups.size(); ++g)
|
||||||
|
{
|
||||||
|
if(groups[g]<0||groups[g]>=shape.size()) laerror("illegal group number in merge_index_groups");
|
||||||
|
if(selected[g]) laerror("repeated group number in merge_index_groups");
|
||||||
|
selected.set(g);
|
||||||
|
p[gg++] = 1+groups[g];
|
||||||
|
}
|
||||||
|
for(int g=0; g<shape.size(); ++g)
|
||||||
|
if(!selected[g])
|
||||||
|
p[gg++] = 1+g;
|
||||||
|
if(gg!=shape.size() || !p.is_valid()) laerror("internal error in merge_index_groups");
|
||||||
|
|
||||||
|
Tensor<T> r = permute_index_groups(p);
|
||||||
|
r.merge_adjacent_index_groups(0,groups.size()-1);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template class Tensor<double>;
|
template class Tensor<double>;
|
||||||
template class Tensor<std::complex<double> >;
|
template class Tensor<std::complex<double> >;
|
||||||
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);
|
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);
|
||||||
|
4
tensor.h
4
tensor.h
@ -190,7 +190,9 @@ public:
|
|||||||
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra<int,T> &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list:
|
||||||
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
// this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I))
|
||||||
// PermutationAlgebra can represent e.g. general_antisymmetrizer in Kucharski-Bartlett notation
|
// PermutationAlgebra can represent e.g. general_antisymmetrizer in Kucharski-Bartlett notation
|
||||||
|
void split_index_group(int group); //formal split of a non-symmetric index group WITHOUT the need for data reorganization
|
||||||
|
void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization
|
||||||
|
Tensor merge_index_groups(const NRVec<int> &groups) const;
|
||||||
|
|
||||||
//TODO perhaps implement application of a permutation algebra to a product of several tensors
|
//TODO perhaps implement application of a permutation algebra to a product of several tensors
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user