tensor: scalar as rank==0 special case
This commit is contained in:
37
t.cc
37
t.cc
@@ -3362,7 +3362,7 @@ for(int i=0; i<n; ++i)
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
if(1)
|
if(0)
|
||||||
{
|
{
|
||||||
int n=5;
|
int n=5;
|
||||||
INDEXGROUP ag;
|
INDEXGROUP ag;
|
||||||
@@ -3959,6 +3959,41 @@ for(int k=0; k<n; ++k)
|
|||||||
cout<<t2;
|
cout<<t2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
//scalar as trivial tensor rank 0 case
|
||||||
|
Tensor<double> t(2.),u(3.);
|
||||||
|
cout <<t<<endl;
|
||||||
|
SUPERINDEX I;
|
||||||
|
FLATINDEX J;
|
||||||
|
cout <<t(I)<<endl;
|
||||||
|
cout <<u(J)<<endl;
|
||||||
|
cout <<t+u<<endl;
|
||||||
|
cout <<t*u<<endl;
|
||||||
|
cout <<t.dot(u)<<endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if(1)
|
||||||
|
{
|
||||||
|
int r,n;
|
||||||
|
cin>>r>>n;
|
||||||
|
INDEXGROUP shape;
|
||||||
|
{
|
||||||
|
shape.number=r;
|
||||||
|
//shape.symmetry= 0;
|
||||||
|
shape.symmetry= -1;
|
||||||
|
shape.range=n;
|
||||||
|
shape.offset=0;
|
||||||
|
}
|
||||||
|
Tensor<double> x(shape);
|
||||||
|
x.randomize(1.);
|
||||||
|
cout<<x;
|
||||||
|
Tensor<double> xf=x.flatten();
|
||||||
|
cout <<xf;
|
||||||
|
|
||||||
|
cout <<x.dot(x) <<" "<< xf.dot(xf)<<endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}//main
|
}//main
|
||||||
|
|||||||
18
tensor.cc
18
tensor.cc
@@ -34,7 +34,7 @@ int r=0;
|
|||||||
for(int i=0; i<shape.size(); ++i)
|
for(int i=0; i<shape.size(); ++i)
|
||||||
{
|
{
|
||||||
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[i];
|
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[i];
|
||||||
if(sh->number<=0) laerror("empty index group"); //we do not support scalar as a trivial case
|
if(sh->number<=0) laerror("empty index group"); //scalar must have shape.size()==0, not empty index group
|
||||||
r+=sh->number;
|
r+=sh->number;
|
||||||
}
|
}
|
||||||
myrank=r;
|
myrank=r;
|
||||||
@@ -46,7 +46,7 @@ return r;
|
|||||||
template<typename T>
|
template<typename T>
|
||||||
LA_largeindex Tensor<T>::calcsize()
|
LA_largeindex Tensor<T>::calcsize()
|
||||||
{
|
{
|
||||||
if(shape.size()==0) laerror("tensor must have rank at least 1");
|
if(shape.size()==0) return 1; //scalar
|
||||||
groupsizes.resize(shape.size());
|
groupsizes.resize(shape.size());
|
||||||
cumsizes.resize(shape.size());
|
cumsizes.resize(shape.size());
|
||||||
LA_largeindex s=1;
|
LA_largeindex s=1;
|
||||||
@@ -578,6 +578,7 @@ std::ostream & operator<<(std::ostream &s, const Tensor<T> &x)
|
|||||||
{
|
{
|
||||||
s<<x.shape;
|
s<<x.shape;
|
||||||
s<<x.names;
|
s<<x.names;
|
||||||
|
if(x.rank()==0) {s<<x.data[0]; return s;}
|
||||||
sout= &s;
|
sout= &s;
|
||||||
x.constloopover(&outputcallback<T>);
|
x.constloopover(&outputcallback<T>);
|
||||||
return s;
|
return s;
|
||||||
@@ -1383,25 +1384,26 @@ loopover(permutationalgebra_callback2);
|
|||||||
template<typename T>
|
template<typename T>
|
||||||
void Tensor<T>::split_index_group(int group)
|
void Tensor<T>::split_index_group(int group)
|
||||||
{
|
{
|
||||||
|
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[0];
|
||||||
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||||
if(shape[group].number==1) return; //nothing to split
|
if(sh[group].number==1) return; //nothing to split
|
||||||
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
if(sh[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
||||||
|
|
||||||
NRVec<indexgroup> newshape(shape.size()+shape[group].number-1);
|
NRVec<indexgroup> newshape(shape.size()+sh[group].number-1);
|
||||||
int gg=0;
|
int gg=0;
|
||||||
for(int g=0; g<shape.size(); ++g)
|
for(int g=0; g<shape.size(); ++g)
|
||||||
{
|
{
|
||||||
if(g==group)
|
if(g==group)
|
||||||
{
|
{
|
||||||
for(int i=0; i<shape[group].number; ++i)
|
for(int i=0; i<sh[group].number; ++i)
|
||||||
{
|
{
|
||||||
newshape[gg] = shape[group];
|
newshape[gg] = sh[group];
|
||||||
newshape[gg].number = 1;
|
newshape[gg].number = 1;
|
||||||
gg++;
|
gg++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
newshape[gg++] = shape[g];
|
newshape[gg++] = sh[g];
|
||||||
}
|
}
|
||||||
|
|
||||||
shape=newshape;
|
shape=newshape;
|
||||||
|
|||||||
4
tensor.h
4
tensor.h
@@ -40,7 +40,6 @@
|
|||||||
#include "miscfunc.h"
|
#include "miscfunc.h"
|
||||||
|
|
||||||
//TODO:
|
//TODO:
|
||||||
//@@@@@@how to handle contractions yielding a scalar - special treatment, support special case of rank=0 tensor?
|
|
||||||
//@@@contraction inside one tensor - compute resulting shape, loopover the shape, create index into the original tensor + loop over the contr. index, do the summation, store result
|
//@@@contraction inside one tensor - compute resulting shape, loopover the shape, create index into the original tensor + loop over the contr. index, do the summation, store result
|
||||||
//@@@ will need to store vector of INDEX to the original tensor for the result's flatindex
|
//@@@ will need to store vector of INDEX to the original tensor for the result's flatindex
|
||||||
//@@@ will not be particularly efficient
|
//@@@ will not be particularly efficient
|
||||||
@@ -187,7 +186,8 @@ public:
|
|||||||
SUPERINDEX inverse_index(LA_largeindex s) const; //inefficient, but possible if needed
|
SUPERINDEX inverse_index(LA_largeindex s) const; //inefficient, but possible if needed
|
||||||
|
|
||||||
//constructors
|
//constructors
|
||||||
Tensor() : myrank(0) {};
|
Tensor() : myrank(-1) {};
|
||||||
|
explicit Tensor(const T &x) : myrank(0), data(1) {data[0]=x;}; //scalar
|
||||||
Tensor(const NRVec<indexgroup> &s) : shape(s) { data.resize(calcsize()); calcrank();}; //general tensor
|
Tensor(const NRVec<indexgroup> &s) : shape(s) { data.resize(calcsize()); calcrank();}; //general tensor
|
||||||
Tensor(const NRVec<indexgroup> &s, const NRVec<INDEXNAME> &newnames) : shape(s), names(newnames) { data.resize(calcsize()); calcrank(); if(names.size()!=myrank && names.size()!=0) laerror("bad number of index names");}; //general tensor
|
Tensor(const NRVec<indexgroup> &s, const NRVec<INDEXNAME> &newnames) : shape(s), names(newnames) { data.resize(calcsize()); calcrank(); if(names.size()!=myrank && names.size()!=0) laerror("bad number of index names");}; //general tensor
|
||||||
Tensor(const NRVec<indexgroup> &s, const NRVec<T> &mydata) : shape(s) { LA_largeindex dsize=calcsize(); calcrank(); if(mydata.size()!=dsize) laerror("inconsistent data size with shape"); data=mydata;}
|
Tensor(const NRVec<indexgroup> &s, const NRVec<T> &mydata) : shape(s) { LA_largeindex dsize=calcsize(); calcrank(); if(mydata.size()!=dsize) laerror("inconsistent data size with shape"); data=mydata;}
|
||||||
|
|||||||
Reference in New Issue
Block a user