tensor: scalar as rank==0 special case
This commit is contained in:
18
tensor.cc
18
tensor.cc
@@ -34,7 +34,7 @@ int r=0;
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[i];
|
||||
if(sh->number<=0) laerror("empty index group"); //we do not support scalar as a trivial case
|
||||
if(sh->number<=0) laerror("empty index group"); //scalar must have shape.size()==0, not empty index group
|
||||
r+=sh->number;
|
||||
}
|
||||
myrank=r;
|
||||
@@ -46,7 +46,7 @@ return r;
|
||||
template<typename T>
|
||||
LA_largeindex Tensor<T>::calcsize()
|
||||
{
|
||||
if(shape.size()==0) laerror("tensor must have rank at least 1");
|
||||
if(shape.size()==0) return 1; //scalar
|
||||
groupsizes.resize(shape.size());
|
||||
cumsizes.resize(shape.size());
|
||||
LA_largeindex s=1;
|
||||
@@ -578,6 +578,7 @@ std::ostream & operator<<(std::ostream &s, const Tensor<T> &x)
|
||||
{
|
||||
s<<x.shape;
|
||||
s<<x.names;
|
||||
if(x.rank()==0) {s<<x.data[0]; return s;}
|
||||
sout= &s;
|
||||
x.constloopover(&outputcallback<T>);
|
||||
return s;
|
||||
@@ -1383,25 +1384,26 @@ loopover(permutationalgebra_callback2);
|
||||
template<typename T>
|
||||
void Tensor<T>::split_index_group(int group)
|
||||
{
|
||||
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[0];
|
||||
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||
if(shape[group].number==1) return; //nothing to split
|
||||
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
||||
if(sh[group].number==1) return; //nothing to split
|
||||
if(sh[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
||||
|
||||
NRVec<indexgroup> newshape(shape.size()+shape[group].number-1);
|
||||
NRVec<indexgroup> newshape(shape.size()+sh[group].number-1);
|
||||
int gg=0;
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
{
|
||||
if(g==group)
|
||||
{
|
||||
for(int i=0; i<shape[group].number; ++i)
|
||||
for(int i=0; i<sh[group].number; ++i)
|
||||
{
|
||||
newshape[gg] = shape[group];
|
||||
newshape[gg] = sh[group];
|
||||
newshape[gg].number = 1;
|
||||
gg++;
|
||||
}
|
||||
}
|
||||
else
|
||||
newshape[gg++] = shape[g];
|
||||
newshape[gg++] = sh[g];
|
||||
}
|
||||
|
||||
shape=newshape;
|
||||
|
||||
Reference in New Issue
Block a user