tensor: permutation of indices inside a symmetry group

This commit is contained in:
2025-11-18 18:09:04 +01:00
parent 20a61e2fb9
commit 78569ca701
6 changed files with 159 additions and 3 deletions

2
mat.cc
View File

@@ -2137,6 +2137,8 @@ NRMat< std::complex<double> >::operator*(const NRSMat< std::complex<double> > &r
******************************************************************************/
template<typename T>
NRMat<T>& NRMat<T>::conjugateme() {
if(!LA_traits<T>::is_complex()) return *this;
copyonwrite();
#ifdef CUDALA
if(location != cpu) laerror("general conjugation only on CPU");
#endif

View File

@@ -872,6 +872,9 @@ NRSMat<std::complex<double> > NRSMat<std::complex<double> >::inverse() {return
******************************************************************************/
template<typename T>
NRSMat<T>& NRSMat<T>::conjugateme() {
if(!LA_traits<T>::is_complex()) return *this;
copyonwrite();
#ifdef CUDALA
if(location != cpu) laerror("general conjugation only on CPU");
#endif

View File

@@ -142,8 +142,8 @@ switch(I.size()) //a few special cases for efficiency
NRVec<LA_index> II(I);
II.copyonwrite();
if(g.offset!=0) II -= g.offset;
int parity=netsort(II.size(),&II[0]);
*sign= (parity&1) ? g.symmetry : 1;
int nswaps=netsort(II.size(),&II[0]);
*sign= (nswaps&1) ? g.symmetry : 1;
if(g.symmetry == -1) //antisymmetric - do not store zero diagonal
{
for(int i=0; i<I.size()-1; ++i)
@@ -2165,6 +2165,40 @@ apply_permutation_algebra(tmp,pb,true,(T)1,beta);
}
template<typename T>
void Tensor<T>::permute_inside_group(int g, const NRPerm<int> &p, bool inverse)
{
if(g<0||g>=shape.size()) laerror("group out of range");
if(shape[g].symmetry==0) laerror("permutation possible only inside symmetric index groups");
if(shape[g].number!=p.size()) laerror("permutation size mismatch to index number");
if(!p.is_valid()) laerror("invalid permutation in permute_inside_group");
int par=p.parity();
if(par<0)
{
switch(shape[g].symmetry)
{
case 1:
break;
case -1:
data.negateme();
break;
case 2:
data.conjugateme();
break;
case -2:
data.negateconjugateme();
break;
}
}
if(is_named())
{
int ii=0;
for(int i=0; i<g; ++i) ii+= shape[i].number;
NRVec<INDEXNAME> tmp=(names.subvector(ii,ii+p.size()-1)).permuted(p,inverse);
for(int i=0; i<p.size(); ++i) names[ii+i]=tmp[i];
}
}
template class Tensor<double>;

View File

@@ -384,6 +384,7 @@ public:
void split_index_group(int group); //formal in-place split of a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement
void split_index_group1(int group); //formal in-place split of the leftmost index in a non-symmetric index group WITHOUT the need for data reorganization or names rearrangement
void merge_adjacent_index_groups(int groupfrom, int groupto); //formal merge of non-symmetric index groups WITHOUT the need for data reorganization or names rearrangement
void permute_inside_group(int group, const NRPerm<int> &p, bool inverse=false); //permute indices inside a symmetric index group only
Tensor merge_index_groups(const NRVec<int> &groups) const;
Tensor flatten(int group= -1) const; //split and uncompress a given group or all of them, leaving flat index order the same
@@ -653,7 +654,6 @@ NRMat<T> mat(t.data,range*(range-1)/2,range*(range-1)/2);
f=fourindex_dense<antisymtwoelectronrealdirac,T,I>(range,NRSMat<T>(mat)); //symmetrize mat
}
//@@@formal permutation of names inside a sym/antisy group (with possible sign change)
template <typename T>

114
vec.cc
View File

@@ -869,6 +869,7 @@ return -1;
******************************************************************************/
template<typename T>
NRVec<T>& NRVec<T>::conjugateme() {
if(!LA_traits<T>::is_complex()) return *this;
copyonwrite();
#ifdef CUDALA
if(location != cpu) laerror("general conjugation only on CPU");
@@ -877,6 +878,29 @@ copyonwrite();
return *this;
}
template<typename T>
NRVec<T>& NRVec<T>::negateconjugateme() {
if(!LA_traits<T>::is_complex()) negateme();
copyonwrite();
#ifdef CUDALA
if(location != cpu) laerror("general conjugation only on CPU");
#endif
for(int i=0; i<nn; ++i) v[i] = -LA_traits<T>::conjugate(v[i]);
return *this;
}
template<typename T>
NRVec<T>& NRVec<T>::negateme() {
copyonwrite();
#ifdef CUDALA
if(location != cpu) laerror("general conjugation only on CPU");
#endif
for(int i=0; i<nn; ++i) v[i] = -v[i];
return *this;
}
/***************************************************************************//**
* conjugate this complex vector
@@ -912,6 +936,96 @@ NRVec<std::complex<float> >& NRVec<std::complex<float> >::conjugateme() {
return *this;
}
template<>
NRVec<std::complex<double> >& NRVec<std::complex<double> >::negateconjugateme() {
copyonwrite();
#ifdef CUDALA
if(location == cpu){
#endif
cblas_dscal((size_t)nn, -1.0, ((double *)v) , 2);
#ifdef CUDALA
}else{
cublasDscal((size_t)nn, -1.0, ((double *)v) , 2);
}
#endif
return *this;
}
template<>
NRVec<std::complex<float> >& NRVec<std::complex<float> >::negateconjugateme() {
copyonwrite();
#ifdef CUDALA
if(location == cpu){
#endif
cblas_sscal((size_t)nn, -1.0, ((float *)v) , 2);
#ifdef CUDALA
}else{
cublasSscal((size_t)nn, -1.0, ((float *)v) , 2);
}
#endif
return *this;
}
template<>
NRVec<std::complex<double> >& NRVec<std::complex<double> >::negateme() {
copyonwrite();
#ifdef CUDALA
if(location == cpu){
#endif
cblas_dscal((size_t)nn*2, -1.0, ((double *)v) , 1);
#ifdef CUDALA
}else{
cublasDscal((size_t)nn*2, -1.0, ((double *)v) , 1);
}
#endif
return *this;
}
template<>
NRVec<std::complex<float> >& NRVec<std::complex<float> >::negateme() {
copyonwrite();
#ifdef CUDALA
if(location == cpu){
#endif
cblas_sscal((size_t)nn*2, -1.0, ((float *)v) , 1);
#ifdef CUDALA
}else{
cublasSscal((size_t)nn*2, -1.0, ((float *)v) , 1);
}
#endif
return *this;
}
template<>
NRVec<double>& NRVec<double>::negateme() {
copyonwrite();
#ifdef CUDALA
if(location == cpu){
#endif
cblas_dscal((size_t)nn, -1.0, v , 1);
#ifdef CUDALA
}else{
cublasDscal((size_t)nn, -1.0, v , 1);
}
#endif
return *this;
}
template<>
NRVec<float>& NRVec<float>::negateme() {
copyonwrite();
#ifdef CUDALA
if(location == cpu){
#endif
cblas_sscal((size_t)nn, -1.0, v , 1);
#ifdef CUDALA
}else{
cublasSscal((size_t)nn, -1.0, v , 1);
}
#endif
return *this;
}
/***************************************************************************//**
* sum up the elements of current vector of general type <code>T</code>

3
vec.h
View File

@@ -303,6 +303,9 @@ public:
NRVec& conjugateme();
inline NRVec conjugate() const {NRVec r(*this); r.conjugateme(); return r;};
NRVec& negateconjugateme();
NRVec& negateme();
//! determine the actual value of the reference counter
inline int getcount() const {return count?*count:0;}