From 883d201e67127d22076fa57a67e4ff001c91c77a Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Mon, 6 May 2024 18:30:01 +0200 Subject: [PATCH] tensor: split and merge index groups --- tensor.cc | 88 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ tensor.h | 4 ++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/tensor.cc b/tensor.cc index cb4b3f9..0ca3599 100644 --- a/tensor.cc +++ b/tensor.cc @@ -22,6 +22,7 @@ #include "qsort.h" #include "miscfunc.h" #include +#include "bitvector.h" namespace LA { @@ -752,6 +753,93 @@ loopover(permutationalgebra_callback); } +template +void Tensor::split_index_group(int group) +{ +if(group<0||group >= shape.size()) laerror("illegal index group number"); +if(shape[group].number==1) return; //nothing to split +if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted"); + +NRVec newshape(shape.size()+shape[group].number-1); +int gg=0; +for(int g=0; g +void Tensor:: merge_adjacent_index_groups(int groupfrom, int groupto) +{ +if(groupfrom<0||groupfrom>= shape.size()) laerror("illegal index group number"); +if(groupto<0||groupto>= shape.size()) laerror("illegal index group number"); +if(groupfrom==groupto) return; +if(groupfrom>groupto) {int t=groupfrom; groupfrom=groupto; groupto=t;} + +int newnumber=0; +for(int g=groupfrom; g<=groupto; ++g) + { + if(shape[g].symmetry!=0) laerror("only non-symmetric index groups can be merged"); + if(shape[g].offset!=shape[groupfrom].offset) laerror("incompatible offset in merge_adjacent_index_groups"); + if(shape[g].range!=shape[groupfrom].range) laerror("incompatible range in merge_adjacent_index_groups"); + newnumber += shape[g].number; + } + +NRVec newshape(shape.size()-(groupto-groupfrom+1)+1); +for(int g=0; g<=groupfrom; ++g) newshape[g]=shape[g]; +newshape[groupfrom].number=newnumber; +for(int g=groupfrom+1; g +Tensor Tensor::merge_index_groups(const NRVec &groups) const +{ +if(groups.size()<=1) return *this; +NRPerm p(shape.size()); +int gg=0; +bitvector selected(shape.size()); +selected.clear(); +for(int g=0; g=shape.size()) laerror("illegal group number in merge_index_groups"); + if(selected[g]) laerror("repeated group number in merge_index_groups"); + selected.set(g); + p[gg++] = 1+groups[g]; + } +for(int g=0; g r = permute_index_groups(p); +r.merge_adjacent_index_groups(0,groups.size()-1); +return r; +} + + + + + template class Tensor; template class Tensor >; template std::ostream & operator<<(std::ostream &s, const Tensor &x); diff --git a/tensor.h b/tensor.h index b34b171..c71f86a 100644 --- a/tensor.h +++ b/tensor.h @@ -190,7 +190,9 @@ public: void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list: // this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I)) // PermutationAlgebra can represent e.g. general_antisymmetrizer in Kucharski-Bartlett notation - + void split_index_group(int group); //formal split of a non-symmetric index group WITHOUT the need for data reorganization + void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization + Tensor merge_index_groups(const NRVec &groups) const; //TODO perhaps implement application of a permutation algebra to a product of several tensors };