tensor: support for complex (anti)hermitian tensors

This commit is contained in:
2025-11-18 17:30:58 +01:00
parent 417a7d1d1a
commit 20a61e2fb9
5 changed files with 101 additions and 36 deletions

View File

@@ -19,8 +19,9 @@
//a simple tensor class with arbitrary symmetry of index subgroups
//stored in an efficient way
//indices can optionally have names and by handled by name
//each index group has a specific symmetry (nosym,sym,antisym)
//indices can optionally have names and be handled by name
//each index group has a specific symmetry (antihermitean= -2, antisym= -1, nosymmetry= 0, symmetric= 1,hermitean=2)
//NOTE: diagonal elements of antihermitean and hermitean matrices are stored including the zero imag/real part and the zeroness is NOT checked and similarly for higher rank tensors
//additional symmetry between index groups (like in 2-electron integrals) is not supported directly, you would need to nest the class to Tensor<Tensor<T> >
//leftmost index is least significant (changing fastest) in the storage order
//presently only a rudimentary implementation
@@ -48,6 +49,45 @@
namespace LA {
template<typename T>
inline T signeddata(const int sgn, const T data, const bool lhs=false)
{
if(LA_traits<T>::is_complex()) //condition known at compile time
{
switch(sgn)
{
case 2:
return LA_traits<T>::conjugate(data);
break;
case 1:
return data;
break;
case -1:
return -data;
break;
case -2:
return -LA_traits<T>::conjugate(data);
break;
case 0:
#ifdef DEBUG
if(lhs) laerror("dereferencing lhs Signedpointer to nonexistent tensor element");
#endif
return 0;
break;
}
return 0;
}
else // for real
{
if(sgn>0) return data;
if(sgn<0) return -data;
#ifdef DEBUG
if(sgn==0 && lhs) laerror("dereferencing lhs Signedpointer to nonexistent tensor element");
#endif
return 0;
}
}
template<typename T>
class Signedpointer
@@ -57,19 +97,11 @@ int sgn;
public:
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
//dereferencing *ptr should be ignored for sgn==0
const T operator=(const T rhs)
{
if(sgn>0) *ptr = rhs;
if(sgn<0) *ptr = -rhs;
#ifdef DEBUG
if(sgn==0) laerror("dereferencing lhs Signedpointer to nonexistent tensor element");
#endif
return rhs;
}
T& operator*=(const T rhs) {*ptr *= rhs; return *ptr;}
T& operator/=(const T rhs) {*ptr /= rhs; return *ptr;}
T& operator+=(const T rhs) {if(sgn>0) *ptr += rhs; else *ptr -= rhs; return *ptr;}
T& operator-=(const T rhs) {if(sgn>0) *ptr -= rhs; else *ptr += rhs; return *ptr;}
const T operator=(const T rhs) {*ptr = signeddata(sgn,rhs); return rhs;}
void operator*=(const T rhs) {*ptr *= rhs;}
void operator/=(const T rhs) {*ptr /= rhs;}
void operator+=(T rhs) {*ptr += signeddata(sgn,rhs);}
void operator-=(T rhs) {*ptr -= signeddata(sgn,rhs);}
};
@@ -104,7 +136,7 @@ class LA_traits<INDEXNAME> {
typedef class INDEXGROUP {
public:
int number; //number of indices
int symmetry; //-1 0 or 1, later 2 for hermitian and -2 for antihermitian? - would need change in operator() and Signedpointer
int symmetry; //-1 0 or 1, later 2 for hermitian and -2 for antihermitian
#ifdef LA_TENSOR_ZERO_OFFSET
static const LA_index offset = 0; //compiler can optimize away some computations
#else
@@ -229,6 +261,7 @@ public:
bool is_flat() const {for(int i=0; i<shape.size(); ++i) if(shape[i].number>1) return false; return true;};
bool is_compressed() const {for(int i=0; i<shape.size(); ++i) if(shape[i].number>1&&shape[i].symmetry!=0) return true; return false;};
bool has_symmetry() const {for(int i=0; i<shape.size(); ++i) if(shape[i].symmetry!=0) return true; return false;};
bool has_hermiticity() const {if(!LA_traits<T>::is_complex()) return false; for(int i=0; i<shape.size(); ++i) if(shape[i].symmetry < -1 || shape[i].symmetry > 1) return true; return false;};
void clear() {data.clear();};
void defaultnames(const char *basename="i") {names.resize(rank()); for(int i=0; i<rank(); ++i) sprintf(names[i].name,"%s%03d",basename,i);}
int rank() const {return myrank;};
@@ -239,12 +272,13 @@ public:
void copyonwrite() {shape.copyonwrite(); groupsizes.copyonwrite(); cumsizes.copyonwrite(); data.copyonwrite(); names.copyonwrite();};
void resize(const NRVec<INDEXGROUP> &s) {shape=s; data.resize(calcsize()); calcrank(); names.clear();};
void deallocate() {data.resize(0); shape.resize(0); groupsizes.resize(0); cumsizes.resize(0); names.resize(0);};
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); return signeddata(sign,data[i]);};
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); return Signedpointer<T>(&data[i],sign);};
inline T operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline T operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); return signeddata(sign,data[i]);};
inline Signedpointer<T> lhs(LA_index i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer<T>(&data[i],sign); };
inline T operator()(LA_index i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];};
inline T operator()(LA_index i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return signeddata(sign,data[i]);};
inline Tensor& operator=(const Tensor &rhs) {myrank=rhs.myrank; shape=rhs.shape; groupsizes=rhs.groupsizes; cumsizes=rhs.cumsizes; data=rhs.data; names=rhs.names; return *this;};
@@ -298,7 +332,7 @@ public:
void put(int fd, bool with_names=false) const;
void get(int fd, bool with_names=false);
inline void randomize(const typename LA_traits<T>::normtype &x) {data.randomize(x);};
inline void randomize(const typename LA_traits<T>::normtype &x) {if(has_hermiticity()) laerror("randomization does not support correct treatment of hermitean/antihermitean index groups"); data.randomize(x);};
void loopover(void (*callback)(const SUPERINDEX &, T *)); //loop over all elements
void constloopover(void (*callback)(const SUPERINDEX &, const T *)) const; //loop over all elements