From 87dd0c5b65e54de6a8801f3f4efd51cccc89ff40 Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Thu, 4 Apr 2024 12:12:12 +0200 Subject: [PATCH] working on tensor --- tensor.cc | 21 +++++++++++++++++++++ tensor.h | 28 +++++++++++++++++++--------- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/tensor.cc b/tensor.cc index 973dc88..ac30ef3 100644 --- a/tensor.cc +++ b/tensor.cc @@ -158,6 +158,27 @@ LA_largeindex Tensor::vindex(int *sign, int i1, va_list args) const } +//binary I/O + +template +void Tensor::put(int fd) const +{ +shape.put(fd,true); +cumsizes.put(fd,true); +data.put(fd,true); +} + +template +void Tensor::get(int fd) +{ +shape.get(fd,true); +cumsizes.get(fd,true); +data.get(fd,true); +} + + + + template class Tensor; template class Tensor >; diff --git a/tensor.h b/tensor.h index 103cc9c..fc44419 100644 --- a/tensor.h +++ b/tensor.h @@ -43,8 +43,12 @@ T *ptr; int sgn; public: Signedpointer(T *p, int s) : ptr(p),sgn(s) {}; - T& operator=(const T rhs) {if(sgn==0) return *ptr; if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;} -//@@@@@@operations like += etc. on singedpointer as LHS of the non-const tensor.operator() expressions + //dereferencing *ptr should intentionally segfault for sgn==0 + T& operator=(const T rhs) {if(sgn>0) *ptr=rhs; else *ptr = -rhs; return *ptr;} + T& operator*=(const T rhs) {*ptr *= rhs; return *ptr;} + T& operator/=(const T rhs) {*ptr /= rhs; return *ptr;} + T& operator+=(const T rhs) {if(sgn>0) *ptr += rhs; else *ptr -= rhs; return *ptr;} + T& operator-=(const T rhs) {if(sgn>0) *ptr -= rhs; else *ptr += rhs; return *ptr;} }; @@ -65,6 +69,11 @@ class LA_traits { static bool is_plaindata() {return true;}; static void copyonwrite(indexgroup& x) {}; typedef INDEXGROUP normtype; + static inline void put(int fd, const indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=write(fd,&x,sizeof(indexgroup))) laerror("write error 1 in indexgroup put"); } + static inline void multiput(int nn, int fd, const indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=write(fd,x,nn*sizeof(indexgroup))) laerror("write error 1 in indexgroup multiiput"); } + static inline void get(int fd, indexgroup &x, bool dimensions=1) {if(sizeof(indexgroup)!=read(fd,&x,sizeof(indexgroup))) laerror("read error 1 in indexgroup get");} + static inline void multiget(int nn, int fd, indexgroup *x, bool dimensions=1) {if(nn*sizeof(indexgroup)!=read(fd,x,nn*sizeof(indexgroup))) laerror("read error 1 in indexgroup get");} + }; @@ -74,13 +83,10 @@ typedef NRVec > SUPERINDEX; //all indices in the INDEXGROUP stru template class Tensor { -//essential data NRVec shape; + NRVec cumsizes; //cumulative sizes of symmetry index groups (a function of shape but precomputed for efficiency) NRVec data; -//redundant data to facilitate efficient indexing - NRVec cumsizes; //cumulative sizes of symmetry index groups - private: LA_largeindex index(int *sign, const SUPERINDEX &I) const; //map the tensor indices to the position in data LA_largeindex index(int *sign, const FLATINDEX &I) const; //map the tensor indices to the position in data @@ -102,14 +108,18 @@ public: inline T operator()(const FLATINDEX &I) {int sign; LA_largeindex i=index(&sign,I); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; inline Signedpointer lhs(int i1...) {va_list args; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); return Signedpointer(&data[i],sign); }; inline T operator()(int i1...) {va_list args; ; int sign; LA_largeindex i; va_start(args,i1); i= vindex(&sign, i1,args); if(sign==0) return 0; return sign>0 ?data[i] : -data[i];}; - - //NOTE: for sign==0 data[i] can be undefined pointer, avoid dereferencing it inline Tensor& operator*=(const T &a) {data*=a; return *this;}; + inline Tensor& operator/=(const T &a) {data/=a; return *this;}; + + void put(int fd) const; + void get(int fd); //@@@TODO - unwinding to full size in a specified index //@@@TODO - contractions - basic and efficient - //@@@TODO get/put to file, stream i/o + //@@@dodelat indexy + //@@@ dvojite rekurzivni loopover s callbackem - nebo iterator s funkci next??? + //@@@ stream i/o na zaklade tohoto };