tensor: numpy file i/o

This commit is contained in:
2026-03-09 18:34:56 +01:00
parent 302055db86
commit fa59833e1e
4 changed files with 142 additions and 2 deletions

1
.gitignore vendored
View File

@@ -63,4 +63,5 @@ test_regsurf
*.gcov
gmon.out
npytest
*.npy
# CVS default ignores end

35
t.cc
View File

@@ -159,7 +159,7 @@ cout <<y;
int main()
int main(int argc, char **argv)
{
sigtraceback(SIGSEGV,1);
sigtraceback(SIGABRT,1);
@@ -4746,7 +4746,7 @@ cout <<aa;
cout <<"Error = "<<(ax-aa).norm()<<endl;
}
if(1)
if(0)
{
int n=3;
NRVec<INDEXGROUP> s(4);
@@ -4766,5 +4766,36 @@ Tensor<double> t1 = t.subtensor1(2);
cout <<t1;
}
if(0)
{
int n=2;
NRVec<INDEXGROUP> s(4);
for(int i=0; i<4; ++i)
{
s[i].number=1;
s[i].symmetry=0;
s[i].offset=0;
s[i].range=n+i;
}
Tensor<double> t(s);
t.randomize(1.);
t.numpy_write("test.npy","<f8");
Tensor<double> tt;
tt.numpy_read("test.npy");
cout<<t;
cout<<tt;
cout <<"Error = "<<(t-tt).norm()<<endl;
}
if(1)
{
Tensor<double> tt;
tt.numpy_read(argv[1]);
//cout<<tt;
cout<<"part\n"<<tt.subtensor1(0);
cout<<"part\n"<<tt.subtensor1(1);
cout<<"part\n"<<tt.subtensor1(2);
}
}//main

104
tensor.cc
View File

@@ -2514,6 +2514,110 @@ return r;
}
//NOTE: we do not check numpy type matches T
template<typename T>
void Tensor<T>::numpy_write(const char *name, const char *descr) const
{
if(!is_flat()) laerror("numpy_write only for flat tensors");
int fd=open(name,O_CREAT|O_LARGEFILE|O_RDWR,0777);
if(fd<0) laerror("cannot open in numpy_write");
char magic[]="\x93NUMPY\x01\x00";
if(8!=write(fd,magic,8)) laerror("cannot write 1 in numpy_write");
//construct header
char header[2048];
int16_t header_len=0;
sprintf(header,"{'descr': '%s', 'fortran_order': False, 'shape': (",descr);
for(int i=myrank-1; i>=0; --i)
{
header_len=strlen(header);
sprintf(header+header_len,"%d",shape[i].range);
if(i>0)
{
header_len=strlen(header);
sprintf(header+header_len,", ");
}
}
header_len=strlen(header);
sprintf(header+header_len,"), }");
//pad header by spaces to 8 byte boundary and terminate by \n
header_len=strlen(header);
int x= 8 - (header_len+2+8)%8;
for(int i=0; i<x; ++i) header[header_len+i]=' ';
header_len+=x;
header[header_len-1]='\n';
header[header_len]=0;
if(2!=write(fd,&header_len,2)) laerror("cannot write 2 in numpy_write");
if(header_len!=write(fd,header,header_len)) laerror("cannot write 3 in numpy_write");
if(sizeof(T)*data.size()!= write(fd,&data[0],sizeof(T)*data.size())) laerror("cannot write 4 in numpy_write");
if(close(fd)) laerror("cannot close in numpy_write");
}
//NOTE: we do not check numpy type matches T
template<typename T>
void Tensor<T>::numpy_read(const char *name)
{
int fd=open(name,O_LARGEFILE|O_RDONLY);
if(fd<0) laerror("cannot open in numpy_read");
char magic[]="\x93NUMPY\x01\x00";
char readmagic[8];
if(8!=read(fd,readmagic,8)) laerror("cannot read 1 in numpy_read");
if(memcmp(magic,readmagic,8)) laerror("bad magic or version != 1.0 in numpy_read");
int16_t header_len;
if(2!=read(fd,&header_len,2)) laerror("cannot read 2 in numpy_read");
char header[header_len];
if(header_len!=read(fd,header,header_len)) laerror("cannot read 3 in numpy_read");
char *p=strstr(header,"fortran_order");
if(!p) laerror("cannot find fortran_order in numpy_read");
bool fortranorder=false;
p=strstr(p,": ");
if(strncmp(p+2,"False",5)) fortranorder=true;
//std::cout <<"Fortran order "<<fortranorder<<std::endl;
std::list<int> dimslist;
p=strstr(p,"'shape': (");
if(!p) laerror("cannot find shape in numpy_read");
p+=strlen("'shape': (");
int d;
while(1==sscanf(p,"%d",&d))
{
if(fortranorder) dimslist.push_back(d); else dimslist.push_front(d);
while(isdigit(*p)||isspace(*p)) p++;
if(*p==')') break;
++p; //skip ,
}
NRVec<int> dims(dimslist);
if(dims.size()==0) laerror("zero-dimensional array in numpy_read");
shape.resize(dims.size());
for(int i=0; i<dims.size(); ++i)
{
shape[i].number=1;
shape[i].symmetry=0;
#ifndef LA_TENSOR_ZERO_OFFSET
shape[i].offset=0;
#endif
shape[i].range=dims[i];
}
data.resize(calcsize());
calcrank();
canonicalize_shape();
if(sizeof(T)*data.size()!= read(fd,&data[0],sizeof(T)*data.size())) laerror("cannot read 4 in numpy_read");
if(close(fd)) laerror("cannot close in numpy_read");
}
template class Tensor<double>;

View File

@@ -279,6 +279,10 @@ public:
//subtensor - todo: more general versions
Tensor subtensor1(int i) const; //for one value of rightmost index
//numpy interface
void numpy_write(const char *name, const char *descr="<f8") const;
void numpy_read(const char *name);
//conversions from/to matrix and vector
explicit Tensor(const NRVec<T> &x);
explicit Tensor(const NRMat<T> &x, bool flat=false);