From 78569ca7016f4c194596a5837a8396fa15ff2f7a Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Tue, 18 Nov 2025 18:09:04 +0100 Subject: [PATCH] tensor: permutation of indices inside a symmetry group --- mat.cc | 2 + smat.cc | 3 ++ tensor.cc | 38 +++++++++++++++++- tensor.h | 2 +- vec.cc | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ vec.h | 3 ++ 6 files changed, 159 insertions(+), 3 deletions(-) diff --git a/mat.cc b/mat.cc index df652fa..05cb918 100644 --- a/mat.cc +++ b/mat.cc @@ -2137,6 +2137,8 @@ NRMat< std::complex >::operator*(const NRSMat< std::complex > &r ******************************************************************************/ template NRMat& NRMat::conjugateme() { +if(!LA_traits::is_complex()) return *this; +copyonwrite(); #ifdef CUDALA if(location != cpu) laerror("general conjugation only on CPU"); #endif diff --git a/smat.cc b/smat.cc index 167f12a..aaceec8 100644 --- a/smat.cc +++ b/smat.cc @@ -872,6 +872,9 @@ NRSMat > NRSMat >::inverse() {return ******************************************************************************/ template NRSMat& NRSMat::conjugateme() { +if(!LA_traits::is_complex()) return *this; +copyonwrite(); + #ifdef CUDALA if(location != cpu) laerror("general conjugation only on CPU"); #endif diff --git a/tensor.cc b/tensor.cc index 857424a..ee42b5b 100644 --- a/tensor.cc +++ b/tensor.cc @@ -142,8 +142,8 @@ switch(I.size()) //a few special cases for efficiency NRVec II(I); II.copyonwrite(); if(g.offset!=0) II -= g.offset; - int parity=netsort(II.size(),&II[0]); - *sign= (parity&1) ? g.symmetry : 1; + int nswaps=netsort(II.size(),&II[0]); + *sign= (nswaps&1) ? g.symmetry : 1; if(g.symmetry == -1) //antisymmetric - do not store zero diagonal { for(int i=0; i +void Tensor::permute_inside_group(int g, const NRPerm &p, bool inverse) +{ +if(g<0||g>=shape.size()) laerror("group out of range"); +if(shape[g].symmetry==0) laerror("permutation possible only inside symmetric index groups"); +if(shape[g].number!=p.size()) laerror("permutation size mismatch to index number"); +if(!p.is_valid()) laerror("invalid permutation in permute_inside_group"); +int par=p.parity(); +if(par<0) + { + switch(shape[g].symmetry) + { + case 1: + break; + case -1: + data.negateme(); + break; + case 2: + data.conjugateme(); + break; + case -2: + data.negateconjugateme(); + break; + } + } +if(is_named()) + { + int ii=0; + for(int i=0; i tmp=(names.subvector(ii,ii+p.size()-1)).permuted(p,inverse); + for(int i=0; i; diff --git a/tensor.h b/tensor.h index 3ae3129..5fb4184 100644 --- a/tensor.h +++ b/tensor.h @@ -384,6 +384,7 @@ public: void split_index_group(int group); //formal in-place split of a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement void split_index_group1(int group); //formal in-place split of the leftmost index in a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization or names rearrangement + void permute_inside_group(int group, const NRPerm &p, bool inverse=false); //permute indices inside a symmetric index group only Tensor merge_index_groups(const NRVec &groups) const; Tensor flatten(int group= -1) const; //split and uncompress a given group or all of them, leaving flat index order the same @@ -653,7 +654,6 @@ NRMat mat(t.data,range*(range-1)/2,range*(range-1)/2); f=fourindex_dense(range,NRSMat(mat)); //symmetrize mat } -//@@@formal permutation of names inside a sym/antisy group (with possible sign change) template diff --git a/vec.cc b/vec.cc index 3eeec6c..0e687c6 100644 --- a/vec.cc +++ b/vec.cc @@ -869,6 +869,7 @@ return -1; ******************************************************************************/ template NRVec& NRVec::conjugateme() { +if(!LA_traits::is_complex()) return *this; copyonwrite(); #ifdef CUDALA if(location != cpu) laerror("general conjugation only on CPU"); @@ -877,6 +878,29 @@ copyonwrite(); return *this; } +template +NRVec& NRVec::negateconjugateme() { +if(!LA_traits::is_complex()) negateme(); +copyonwrite(); +#ifdef CUDALA + if(location != cpu) laerror("general conjugation only on CPU"); +#endif + for(int i=0; i::conjugate(v[i]); + return *this; +} + + +template +NRVec& NRVec::negateme() { +copyonwrite(); +#ifdef CUDALA + if(location != cpu) laerror("general conjugation only on CPU"); +#endif + for(int i=0; i >& NRVec >::conjugateme() { return *this; } +template<> +NRVec >& NRVec >::negateconjugateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_dscal((size_t)nn, -1.0, ((double *)v) , 2); +#ifdef CUDALA + }else{ + cublasDscal((size_t)nn, -1.0, ((double *)v) , 2); + } +#endif + return *this; +} + +template<> +NRVec >& NRVec >::negateconjugateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_sscal((size_t)nn, -1.0, ((float *)v) , 2); +#ifdef CUDALA + }else{ + cublasSscal((size_t)nn, -1.0, ((float *)v) , 2); + } +#endif + return *this; +} + +template<> +NRVec >& NRVec >::negateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_dscal((size_t)nn*2, -1.0, ((double *)v) , 1); +#ifdef CUDALA + }else{ + cublasDscal((size_t)nn*2, -1.0, ((double *)v) , 1); + } +#endif + return *this; +} + +template<> +NRVec >& NRVec >::negateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_sscal((size_t)nn*2, -1.0, ((float *)v) , 1); +#ifdef CUDALA + }else{ + cublasSscal((size_t)nn*2, -1.0, ((float *)v) , 1); + } +#endif + return *this; +} + +template<> +NRVec& NRVec::negateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_dscal((size_t)nn, -1.0, v , 1); +#ifdef CUDALA + }else{ + cublasDscal((size_t)nn, -1.0, v , 1); + } +#endif + return *this; +} + +template<> +NRVec& NRVec::negateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_sscal((size_t)nn, -1.0, v , 1); +#ifdef CUDALA + }else{ + cublasSscal((size_t)nn, -1.0, v , 1); + } +#endif + return *this; +} + /***************************************************************************//** * sum up the elements of current vector of general type T diff --git a/vec.h b/vec.h index 6e722ee..3b8ebbe 100644 --- a/vec.h +++ b/vec.h @@ -303,6 +303,9 @@ public: NRVec& conjugateme(); inline NRVec conjugate() const {NRVec r(*this); r.conjugateme(); return r;}; + NRVec& negateconjugateme(); + NRVec& negateme(); + //! determine the actual value of the reference counter inline int getcount() const {return count?*count:0;}