tensor: bugfix in unwind_index and linear_transform of single group
This commit is contained in:
52
t.cc
52
t.cc
@@ -4534,7 +4534,7 @@ for(int i=0; i<m; ++i)
|
||||
|
||||
}
|
||||
|
||||
if(1)
|
||||
if(0)
|
||||
{
|
||||
int n,m;
|
||||
bool which;
|
||||
@@ -4570,6 +4570,56 @@ for(int i=0; i<m; ++i)
|
||||
cout <<endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if(1)
|
||||
{
|
||||
int r,n,sym;
|
||||
cin>>r>>n>>sym;
|
||||
NRVec<INDEXGROUP> shape(3);
|
||||
shape[0].number=2;
|
||||
shape[0].symmetry=0;
|
||||
shape[0].range=n+1;
|
||||
shape[0].offset=0;
|
||||
|
||||
shape[1].number=r;
|
||||
shape[1].symmetry= sym;
|
||||
shape[1].range=n;
|
||||
shape[1].offset=0;
|
||||
|
||||
shape[2].number=2;
|
||||
shape[2].symmetry=0;
|
||||
shape[2].range=n+2;
|
||||
shape[2].offset=0;
|
||||
|
||||
|
||||
Tensor<double> x(shape); x.randomize(1.);
|
||||
x.defaultnames();
|
||||
cout <<"x= "<<x.shape << " "<<x.names<<endl;
|
||||
|
||||
NRVec<INDEXGROUP> yshape(2);
|
||||
yshape[0].number=1;
|
||||
yshape[0].symmetry=0;
|
||||
yshape[0].range=n;
|
||||
yshape[0].offset=0;
|
||||
|
||||
yshape[1].number=1;
|
||||
yshape[1].symmetry= 0;
|
||||
yshape[1].range=n+3;
|
||||
yshape[1].offset=0;
|
||||
|
||||
Tensor<double> y(yshape); y.randomize(1.);
|
||||
|
||||
INDEX posit(1,0);
|
||||
//Tensor<double> z=x.unwind_index(1,1);
|
||||
//Tensor<double> z=x.contraction(INDEX(1,1),y,INDEX(0,0),1,false,false);
|
||||
Tensor<double> z=x.linear_transform(1,y);
|
||||
|
||||
cout <<z.shape;
|
||||
|
||||
//check
|
||||
|
||||
|
||||
}
|
||||
|
||||
}//main
|
||||
|
||||
66
tensor.cc
66
tensor.cc
@@ -713,6 +713,9 @@ return q;
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
|
||||
{
|
||||
if(p.size()!=shape.size()) laerror("permutation size mismatch in permute_index_groups");
|
||||
if(p.is_identity()) return *this;
|
||||
|
||||
//std::cout <<"permute_index_groups permutation = "<<p<<std::endl;
|
||||
NRVec<INDEXGROUP> newshape=shape.permuted(p,true);
|
||||
//std::cout <<"permute_index_groups newshape = "<<newshape<<std::endl;
|
||||
@@ -836,6 +839,7 @@ if(group==0 && index==0 && shape[0].symmetry==0) //formal split of 1 index witho
|
||||
}
|
||||
if(shape[group].number==1) return unwind_index_group(group); //single index in the group
|
||||
|
||||
|
||||
//general case - recalculate the shape and allocate the new tensor
|
||||
NRVec<INDEXGROUP> newshape(shape.size()+1);
|
||||
newshape[0].number=1;
|
||||
@@ -856,7 +860,10 @@ for(int i=0; i<shape.size(); ++i)
|
||||
--newshape[i+1].number;
|
||||
flatindex += index;
|
||||
}
|
||||
else flatindex += shape[i].number;
|
||||
else
|
||||
{
|
||||
if(i<group) flatindex += shape[i].number;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1840,6 +1847,10 @@ for(int g=shape.size()-1; g>=0; --g)
|
||||
Tensor<T> mat(x[g],true); //flat tensor from a matrix
|
||||
for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order
|
||||
{
|
||||
#ifdef LA_TENSOR_INDEXPOSITION
|
||||
//what should we do with index position?
|
||||
//either set upperindex of mat appropriately, or request ignoring that in contractions?
|
||||
#endif
|
||||
r= tmp.contraction(indexposition(rank()-1,tmp.shape),mat,INDEX(0,0),(T)1,false,false); //always the last index
|
||||
if(i>0)
|
||||
{
|
||||
@@ -1870,6 +1881,59 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::linear_transform(const int g, const Tensor<T> &x) const
|
||||
{
|
||||
if(g<0||g>=shape.size()) laerror("wrong index group number in linear_transform");
|
||||
if(x.rank()!=2 || x.shape.size()!=2 || x.shape[0].number!=1||x.shape[1].number!=1) laerror("wrong tensor shape for linear_transform");
|
||||
if(x.shape[0].range!=shape[g].range) laerror("index range mismatch in linear_transform");
|
||||
|
||||
Tensor<T> tmp(*this);
|
||||
Tensor<T> r;
|
||||
|
||||
|
||||
//contract all indices in the reverse order
|
||||
int gnow=g;
|
||||
for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order
|
||||
{
|
||||
r= tmp.contraction(INDEX(gnow,i),x,INDEX(0,0),(T)1,false,false);
|
||||
++gnow; //new group number in r, one index was added left
|
||||
if(i>0)
|
||||
{
|
||||
tmp=r;
|
||||
r.deallocate();
|
||||
}
|
||||
}
|
||||
//the group's indices are now individual ones leftmost in tensor r, restore group size and symmetry
|
||||
if(shape[g].number>1)
|
||||
{
|
||||
INDEXLIST il(shape[g].number);
|
||||
for(int i=0; i<shape[g].number; ++i)
|
||||
{
|
||||
il[i].group=i;
|
||||
il[i].index=0;
|
||||
}
|
||||
r = r.merge_indices(il,shape[g].symmetry);
|
||||
}
|
||||
|
||||
//permute the group (now 0) to its original position
|
||||
if(g!=0)
|
||||
{
|
||||
NRPerm<int> p(r.shape.size());
|
||||
for(int i=0; i<g; ++i) p[i+1] = (i+1)+1;
|
||||
p[g+1]=1;
|
||||
for(int i=g+1; i<r.shape.size();++i) p[i+1] = i+1;
|
||||
//std::cout<<"perm "<<p;
|
||||
r=r.permute_index_groups(p);
|
||||
}
|
||||
|
||||
//preserve names
|
||||
r.names=names;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
int Tensor<T>::findflatindex(const INDEXNAME nam) const
|
||||
|
||||
2
tensor.h
2
tensor.h
@@ -421,6 +421,8 @@ public:
|
||||
NRVec<NRMat<T> > Tucker(typename LA_traits<T>::normtype thr=1e-12, bool inverseorder=false); //HOSVD-Tucker decomposition, return core tensor in *this, flattened
|
||||
Tensor inverseTucker(const NRVec<NRMat<T> > &x, bool inverseorder=false) const; //rebuild the original tensor from Tucker
|
||||
Tensor linear_transform(const NRVec<NRMat<T> > &x) const; //linear transform by a different matrix per each index group, preserving group symmetries
|
||||
Tensor linear_transform(const int g, const Tensor<T> &x) const; //linear transform a single group, preserve its position and symmetry, x must be rank-2 flat tensor (this is useful for raising/lowering indices)
|
||||
Tensor linear_transform(const int g, const NRMat<T> &x) const {Tensor<T> mat(x,true); return linear_transform(g,mat);}; //linear transform a single group, preserve its position and symmetry
|
||||
};
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user