tensor: linear_transform implemented
This commit is contained in:
51
tensor.cc
51
tensor.cc
@@ -1785,7 +1785,6 @@ template<typename T>
|
||||
Tensor<T> Tensor<T>::inverseTucker(const NRVec<NRMat<T> > &x, bool inverseorder) const
|
||||
{
|
||||
if(rank()!=x.size()) laerror("input of inverseTucker does not match rank");
|
||||
NRVec<INDEXNAME> names_saved = names;
|
||||
Tensor<T> tmp(*this);
|
||||
Tensor<T> r;
|
||||
if(!is_flat()) laerror("inverseTucker only for flat tensors as produced by Tucker");
|
||||
@@ -1818,14 +1817,60 @@ for(int i=rank()-1; i>=0; --i)
|
||||
if(inverseorder)
|
||||
{
|
||||
NRPerm<int> p(rank()); p.identity();
|
||||
r.names=names_saved.permuted(p.reverse());
|
||||
r.names=names.permuted(p.reverse());
|
||||
}
|
||||
else
|
||||
r.names=names_saved;
|
||||
r.names=names;
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::linear_transform(const NRVec<NRMat<T> > &x) const
|
||||
{
|
||||
if(x.size()!=shape.size()) laerror("wrong number of transformation matrices in linear_transform");
|
||||
for(int i=0; i<shape.size(); ++i) if(x[i].ncols()!=shape[i].range) laerror("mismatch of transformation matrix size in linear_transform");
|
||||
|
||||
Tensor<T> tmp(*this);
|
||||
Tensor<T> r;
|
||||
|
||||
//do the groups in reverse order to preserve index ordering
|
||||
for(int g=shape.size()-1; g>=0; --g)
|
||||
{
|
||||
Tensor<T> mat(x[g],true); //flat tensor from a matrix
|
||||
for(int i=shape[g].number-1; i>=0; --i) //indices in the group in reverse order
|
||||
{
|
||||
r= tmp.contraction(indexposition(rank()-1,tmp.shape),mat,INDEX(0,0),(T)1,false,false); //always the last index
|
||||
if(i>0)
|
||||
{
|
||||
tmp=r;
|
||||
r.deallocate();
|
||||
}
|
||||
}
|
||||
//the group's indices are now individual ones leftmost in tensor r, restore group size and symmetry
|
||||
if(shape[g].number>1)
|
||||
{
|
||||
INDEXLIST il(shape[g].number);
|
||||
for(int i=0; i<shape[g].number; ++i)
|
||||
{
|
||||
il[i].group=i;
|
||||
il[i].index=0;
|
||||
}
|
||||
r = r.merge_indices(il,shape[g].symmetry);
|
||||
}
|
||||
if(g>0)
|
||||
{
|
||||
tmp=r;
|
||||
r.deallocate();
|
||||
}
|
||||
}
|
||||
|
||||
r.names=names;
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
int Tensor<T>::findflatindex(const INDEXNAME nam) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user