From 8fa7194f2d060dd8dbf4f2b1464abc1df9bf2d3d Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Tue, 26 Mar 2024 17:49:09 +0100 Subject: [PATCH] working on tensor --- tensor.cc | 23 +++++++++++++++++++++++ tensor.h | 26 ++++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/tensor.cc b/tensor.cc index a29d39b..1951334 100644 --- a/tensor.cc +++ b/tensor.cc @@ -23,6 +23,29 @@ namespace LA { +template +LA_largeindex Tensor::index(const SUPERINDEX &I) +{ +//check index structure and ranges +#ifndef DEBUG +if(I.size()!=shape.size()) laerror("mismatch in the number of tensor index groups"); +for(int i=0; i= shape[i].offset+shape[i].size()) + { + std::cerr<<"error in index group no. "< +#include #include "vec.h" #include "miscfunc.h" @@ -44,20 +45,41 @@ LA_index offset; //indices start at LA_index size; //indices span this range } INDEXGROUP; +typedef NRVec FLATINDEX; //all indices but in a single vector +typedef NRVec > SUPERINDEX; //all indices in the INDEXGROUP structure + + template class Tensor { NRVec shape; NRVec data; +private: + LA_largeindex index(const SUPERINDEX &I); //map the tensor indices to the position in data + LA_largeindex index(const FLATINDEX &I); //map the tensor indices to the position in data + LA_largeindex vindex(int i1,va_list args); //map list of indices to the position in data @@@must call va_end + public: Tensor() {}; Tensor(const NRVec &s) : shape(s), data((int)size()) {data.clear();}; int rank() const; //is computed from shape LA_largeindex size() const; //expensive, is computed from shape void copyonwrite() {shape.copyonwrite(); data.copyonwrite();}; - //@@@operator() lhs and rhs both via vararg a via superindex of flat and nested types, get/put to file, stream i/o + inline T& operator()(const SUPERINDEX &I) {return data[index(I)];}; + inline const T& operator()(const SUPERINDEX &I) const {return data[index(I)];}; + inline T& operator()(const FLATINDEX &I) {return data[index(I)];}; + inline const T& operator()(const FLATINDEX &I) const {return data[index(I)];}; + inline T& operator()(int i1...) {va_list args; va_start(args,i1); return data[vindex(i1,args)];}; + inline const T& operator()(int i1...) const {va_list args; va_start(args,i1); return data[vindex(i1,args)];}; + + //@@@TODO - unwinding to full size in a specified index + //@@@TODO - contractions - basic and efficient + //@@@TODO get/put to file, stream i/o }; + + + template int Tensor:: rank() const {