tensor: linear_transform implemented

This commit is contained in:
2025-11-20 18:17:34 +01:00
parent d136c2314d
commit ad1c4ee968
3 changed files with 122 additions and 23 deletions

View File

@@ -83,6 +83,15 @@ if(LA_traits<T>::is_complex()) //condition known at compile time
}
static inline bool ptr_ok(void *ptr)
{
#ifdef DEBUG
if(ptr==NULL) std::cout<<"Warning: write to nonexistent tensor element ignored\n";
#endif
return (bool)ptr;
}
template<typename T>
class Signedpointer
{
@@ -91,11 +100,11 @@ int sgn;
public:
Signedpointer(T *p, int s) : ptr(p),sgn(s) {};
//dereferencing *ptr should be ignored for sgn==0
const T operator=(const T rhs) {*ptr = signeddata(sgn,rhs); return rhs;}
void operator*=(const T rhs) {*ptr *= rhs;}
void operator/=(const T rhs) {*ptr /= rhs;}
void operator+=(T rhs) {*ptr += signeddata(sgn,rhs);}
void operator-=(T rhs) {*ptr -= signeddata(sgn,rhs);}
const T operator=(const T rhs) {if(ptr_ok(ptr)) *ptr = signeddata(sgn,rhs); return rhs;}
void operator*=(const T rhs) {if(ptr_ok(ptr)) *ptr *= rhs;}
void operator/=(const T rhs) {if(ptr_ok(ptr)) *ptr /= rhs;}
void operator+=(T rhs) {if(ptr_ok(ptr)) *ptr += signeddata(sgn,rhs);}
void operator-=(T rhs) {if(ptr_ok(ptr)) *ptr -= signeddata(sgn,rhs);}
};
@@ -182,6 +191,8 @@ struct INDEX
int group;
int index;
bool operator==(const INDEX &rhs) const {return group==rhs.group && index==rhs.index;};
INDEX() {};
INDEX(int g, int i): group(g), index(i) {};
};
typedef NRVec<INDEX> INDEXLIST; //collection of several indices
@@ -223,13 +234,16 @@ public:
int myrank;
NRVec<LA_largeindex> groupsizes; //group sizes of symmetry index groups (a function of shape but precomputed for efficiency)
NRVec<LA_largeindex> cumsizes; //cumulative sizes of symmetry index groups (a function of shape but precomputed for efficiency); always cumsizes[0]=1, index group 0 is the innermost-loop one
public:
NRVec<INDEXNAME> names;
//indexing
LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data
LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data
LA_largeindex vindex(int *sign, LA_index i1, va_list args) const; //map list of indices to the position in data
SUPERINDEX inverse_index(LA_largeindex s) const; //inefficient, but possible if needed
int myflatposition(int group, int index) const {return flatposition(group,index,shape);};
int myflatposition(const INDEX &i) const {return flatposition(i,shape);};
INDEX myindexposition(int flatindex) const {return indexposition(flatindex,shape);};
//constructors
Tensor() : myrank(-1) {};
@@ -273,22 +287,16 @@ public:
void deallocate() {data.resize(0); shape.resize(0); groupsizes.resize(0); cumsizes.resize(0); names.resize(0);};
inline Signedpointer<T> lhs(const SUPERINDEX &I) {int sign; LA_largeindex i=index(&sign,I);
#ifdef DEBUG
if(sign==0) laerror("l-value pointer to nonexistent tensor element");
#endif
return Signedpointer<T>(&data[i],sign);};
if(sign==0) return Signedpointer<T>(NULL,sign);
else return Signedpointer<T>(&data[i],sign);};
inline T operator()(const SUPERINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; else return signeddata(sign,data[i]);};
inline Signedpointer<T> lhs(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I);
#ifdef DEBUG
if(sign==0) laerror("l-value pointer to nonexistent tensor element");
#endif
return Signedpointer<T>(&data[i],sign);};
if(sign==0) return Signedpointer<T>(NULL,sign);
else return Signedpointer<T>(&data[i],sign);};
inline T operator()(const FLATINDEX &I) const {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; else return signeddata(sign,data[i]);};
inline Signedpointer<T> lhs(LA_index i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args);
#ifdef DEBUG
if(sign==0) laerror("l-value pointer to nonexistent tensor element");
#endif
return Signedpointer<T>(&data[i],sign); };
if(sign==0) return Signedpointer<T>(NULL,sign);
else return Signedpointer<T>(&data[i],sign); };
inline T operator()(LA_index i1...) const {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; else return signeddata(sign,data[i]);};
inline Tensor& operator=(const Tensor &rhs) {myrank=rhs.myrank; shape=rhs.shape; groupsizes=rhs.groupsizes; cumsizes=rhs.cumsizes; data=rhs.data; names=rhs.names; return *this;};
@@ -406,6 +414,7 @@ public:
NRVec<NRMat<T> > Tucker(typename LA_traits<T>::normtype thr=1e-12, bool inverseorder=false); //HOSVD-Tucker decomposition, return core tensor in *this, flattened
Tensor inverseTucker(const NRVec<NRMat<T> > &x, bool inverseorder=false) const; //rebuild the original tensor from Tucker
Tensor linear_transform(const NRVec<NRMat<T> > &x) const; //linear transform by a different matrix per each index group, preserving group symmetries
};