tensor: numpy file i/o
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -63,4 +63,5 @@ test_regsurf
|
|||||||
*.gcov
|
*.gcov
|
||||||
gmon.out
|
gmon.out
|
||||||
npytest
|
npytest
|
||||||
|
*.npy
|
||||||
# CVS default ignores end
|
# CVS default ignores end
|
||||||
|
|||||||
35
t.cc
35
t.cc
@@ -159,7 +159,7 @@ cout <<y;
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
int main()
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
sigtraceback(SIGSEGV,1);
|
sigtraceback(SIGSEGV,1);
|
||||||
sigtraceback(SIGABRT,1);
|
sigtraceback(SIGABRT,1);
|
||||||
@@ -4746,7 +4746,7 @@ cout <<aa;
|
|||||||
cout <<"Error = "<<(ax-aa).norm()<<endl;
|
cout <<"Error = "<<(ax-aa).norm()<<endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(1)
|
if(0)
|
||||||
{
|
{
|
||||||
int n=3;
|
int n=3;
|
||||||
NRVec<INDEXGROUP> s(4);
|
NRVec<INDEXGROUP> s(4);
|
||||||
@@ -4766,5 +4766,36 @@ Tensor<double> t1 = t.subtensor1(2);
|
|||||||
cout <<t1;
|
cout <<t1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
int n=2;
|
||||||
|
NRVec<INDEXGROUP> s(4);
|
||||||
|
for(int i=0; i<4; ++i)
|
||||||
|
{
|
||||||
|
s[i].number=1;
|
||||||
|
s[i].symmetry=0;
|
||||||
|
s[i].offset=0;
|
||||||
|
s[i].range=n+i;
|
||||||
|
}
|
||||||
|
Tensor<double> t(s);
|
||||||
|
t.randomize(1.);
|
||||||
|
t.numpy_write("test.npy","<f8");
|
||||||
|
|
||||||
|
Tensor<double> tt;
|
||||||
|
tt.numpy_read("test.npy");
|
||||||
|
cout<<t;
|
||||||
|
cout<<tt;
|
||||||
|
cout <<"Error = "<<(t-tt).norm()<<endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(1)
|
||||||
|
{
|
||||||
|
Tensor<double> tt;
|
||||||
|
tt.numpy_read(argv[1]);
|
||||||
|
//cout<<tt;
|
||||||
|
cout<<"part\n"<<tt.subtensor1(0);
|
||||||
|
cout<<"part\n"<<tt.subtensor1(1);
|
||||||
|
cout<<"part\n"<<tt.subtensor1(2);
|
||||||
|
}
|
||||||
|
|
||||||
}//main
|
}//main
|
||||||
|
|||||||
104
tensor.cc
104
tensor.cc
@@ -2514,6 +2514,110 @@ return r;
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//NOTE: we do not check numpy type matches T
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::numpy_write(const char *name, const char *descr) const
|
||||||
|
{
|
||||||
|
if(!is_flat()) laerror("numpy_write only for flat tensors");
|
||||||
|
int fd=open(name,O_CREAT|O_LARGEFILE|O_RDWR,0777);
|
||||||
|
if(fd<0) laerror("cannot open in numpy_write");
|
||||||
|
char magic[]="\x93NUMPY\x01\x00";
|
||||||
|
if(8!=write(fd,magic,8)) laerror("cannot write 1 in numpy_write");
|
||||||
|
|
||||||
|
//construct header
|
||||||
|
char header[2048];
|
||||||
|
int16_t header_len=0;
|
||||||
|
sprintf(header,"{'descr': '%s', 'fortran_order': False, 'shape': (",descr);
|
||||||
|
for(int i=myrank-1; i>=0; --i)
|
||||||
|
{
|
||||||
|
header_len=strlen(header);
|
||||||
|
sprintf(header+header_len,"%d",shape[i].range);
|
||||||
|
if(i>0)
|
||||||
|
{
|
||||||
|
header_len=strlen(header);
|
||||||
|
sprintf(header+header_len,", ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
header_len=strlen(header);
|
||||||
|
sprintf(header+header_len,"), }");
|
||||||
|
|
||||||
|
//pad header by spaces to 8 byte boundary and terminate by \n
|
||||||
|
header_len=strlen(header);
|
||||||
|
int x= 8 - (header_len+2+8)%8;
|
||||||
|
for(int i=0; i<x; ++i) header[header_len+i]=' ';
|
||||||
|
header_len+=x;
|
||||||
|
header[header_len-1]='\n';
|
||||||
|
header[header_len]=0;
|
||||||
|
|
||||||
|
if(2!=write(fd,&header_len,2)) laerror("cannot write 2 in numpy_write");
|
||||||
|
if(header_len!=write(fd,header,header_len)) laerror("cannot write 3 in numpy_write");
|
||||||
|
|
||||||
|
if(sizeof(T)*data.size()!= write(fd,&data[0],sizeof(T)*data.size())) laerror("cannot write 4 in numpy_write");
|
||||||
|
|
||||||
|
if(close(fd)) laerror("cannot close in numpy_write");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
//NOTE: we do not check numpy type matches T
|
||||||
|
template<typename T>
|
||||||
|
void Tensor<T>::numpy_read(const char *name)
|
||||||
|
{
|
||||||
|
int fd=open(name,O_LARGEFILE|O_RDONLY);
|
||||||
|
if(fd<0) laerror("cannot open in numpy_read");
|
||||||
|
char magic[]="\x93NUMPY\x01\x00";
|
||||||
|
char readmagic[8];
|
||||||
|
if(8!=read(fd,readmagic,8)) laerror("cannot read 1 in numpy_read");
|
||||||
|
if(memcmp(magic,readmagic,8)) laerror("bad magic or version != 1.0 in numpy_read");
|
||||||
|
|
||||||
|
int16_t header_len;
|
||||||
|
if(2!=read(fd,&header_len,2)) laerror("cannot read 2 in numpy_read");
|
||||||
|
|
||||||
|
char header[header_len];
|
||||||
|
if(header_len!=read(fd,header,header_len)) laerror("cannot read 3 in numpy_read");
|
||||||
|
|
||||||
|
char *p=strstr(header,"fortran_order");
|
||||||
|
if(!p) laerror("cannot find fortran_order in numpy_read");
|
||||||
|
bool fortranorder=false;
|
||||||
|
p=strstr(p,": ");
|
||||||
|
if(strncmp(p+2,"False",5)) fortranorder=true;
|
||||||
|
//std::cout <<"Fortran order "<<fortranorder<<std::endl;
|
||||||
|
|
||||||
|
std::list<int> dimslist;
|
||||||
|
p=strstr(p,"'shape': (");
|
||||||
|
if(!p) laerror("cannot find shape in numpy_read");
|
||||||
|
p+=strlen("'shape': (");
|
||||||
|
int d;
|
||||||
|
while(1==sscanf(p,"%d",&d))
|
||||||
|
{
|
||||||
|
if(fortranorder) dimslist.push_back(d); else dimslist.push_front(d);
|
||||||
|
while(isdigit(*p)||isspace(*p)) p++;
|
||||||
|
if(*p==')') break;
|
||||||
|
++p; //skip ,
|
||||||
|
}
|
||||||
|
|
||||||
|
NRVec<int> dims(dimslist);
|
||||||
|
if(dims.size()==0) laerror("zero-dimensional array in numpy_read");
|
||||||
|
|
||||||
|
shape.resize(dims.size());
|
||||||
|
for(int i=0; i<dims.size(); ++i)
|
||||||
|
{
|
||||||
|
shape[i].number=1;
|
||||||
|
shape[i].symmetry=0;
|
||||||
|
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||||
|
shape[i].offset=0;
|
||||||
|
#endif
|
||||||
|
shape[i].range=dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
data.resize(calcsize());
|
||||||
|
calcrank();
|
||||||
|
canonicalize_shape();
|
||||||
|
if(sizeof(T)*data.size()!= read(fd,&data[0],sizeof(T)*data.size())) laerror("cannot read 4 in numpy_read");
|
||||||
|
if(close(fd)) laerror("cannot close in numpy_read");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template class Tensor<double>;
|
template class Tensor<double>;
|
||||||
|
|||||||
4
tensor.h
4
tensor.h
@@ -279,6 +279,10 @@ public:
|
|||||||
//subtensor - todo: more general versions
|
//subtensor - todo: more general versions
|
||||||
Tensor subtensor1(int i) const; //for one value of rightmost index
|
Tensor subtensor1(int i) const; //for one value of rightmost index
|
||||||
|
|
||||||
|
//numpy interface
|
||||||
|
void numpy_write(const char *name, const char *descr="<f8") const;
|
||||||
|
void numpy_read(const char *name);
|
||||||
|
|
||||||
//conversions from/to matrix and vector
|
//conversions from/to matrix and vector
|
||||||
explicit Tensor(const NRVec<T> &x);
|
explicit Tensor(const NRVec<T> &x);
|
||||||
explicit Tensor(const NRMat<T> &x, bool flat=false);
|
explicit Tensor(const NRMat<T> &x, bool flat=false);
|
||||||
|
|||||||
Reference in New Issue
Block a user