tensor: subtensor1()

This commit is contained in:
2026-03-09 17:10:58 +01:00
parent 9b5372fcae
commit 302055db86
3 changed files with 44 additions and 1 deletions

22
t.cc
View File

@@ -1084,7 +1084,7 @@ NRMat<complex<double> > b=exp(a);
cout <<b;
}
if(1)
if(0)
{
int n;
double d;
@@ -4746,5 +4746,25 @@ cout <<aa;
cout <<"Error = "<<(ax-aa).norm()<<endl;
}
if(1)
{
int n=3;
NRVec<INDEXGROUP> s(4);
for(int i=0; i<4; ++i)
{
s[i].number=1;
s[i].symmetry=0;
s[i].offset=0;
s[i].range=n;
}
Tensor<double> t(s);
t.randomize(1.);
INDEXNAME list[4]={"i","j","k","l"};
t.names=list;
cout <<t;
Tensor<double> t1 = t.subtensor1(2);
cout <<t1;
}
}//main

View File

@@ -2496,6 +2496,26 @@ return false;
}
template<typename T>
Tensor<T> Tensor<T>::subtensor1(int i) const
{
int ind=shape.size();
if(ind==0) laerror("subtensor of a scalar");
--ind;
if(shape[ind].number>1) laerror("last index must be standalone in subtensor1");
i -= shape[ind].offset;
if(i<0||i>=shape[ind].range) laerror("index out of range in subtensor1");
if(ind==0) return Tensor(data[i]); //results is a scalar
NRVec<INDEXGROUP> newshape = shape.subvector(0,ind-1);
Tensor<T> r(newshape);
memcpy(&r.data[0],&data[i*cumsizes[ind]],cumsizes[ind]*sizeof(T));
if(is_named()) r.names = names.subvector(0,ind-1);
return r;
}
template class Tensor<double>;
template class Tensor<std::complex<double> >;
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);

View File

@@ -276,6 +276,9 @@ public:
Tensor(int xrank, const NRVec<INDEXGROUP> &xshape, const NRVec<LA_largeindex> &xgroupsizes, const NRVec<LA_largeindex> xcumsizes, const NRVec<T> &xdata) : myrank(xrank), shape(xshape), groupsizes(xgroupsizes), cumsizes(xcumsizes), data(xdata) {};
Tensor(int xrank, const NRVec<INDEXGROUP> &xshape, const NRVec<LA_largeindex> &xgroupsizes, const NRVec<LA_largeindex> xcumsizes, const NRVec<T> &xdata, const NRVec<INDEXNAME> &xnames) : myrank(xrank), shape(xshape), groupsizes(xgroupsizes), cumsizes(xcumsizes), data(xdata), names(xnames) {};
//subtensor - todo: more general versions
Tensor subtensor1(int i) const; //for one value of rightmost index
//conversions from/to matrix and vector
explicit Tensor(const NRVec<T> &x);
explicit Tensor(const NRMat<T> &x, bool flat=false);