small improvements in tensor/tucker
This commit is contained in:
38
tensor.cc
38
tensor.cc
@@ -802,10 +802,30 @@ template<typename T>
|
||||
Tensor<T> Tensor<T>::flatten(int group) const
|
||||
{
|
||||
if(group>=shape.size()) laerror("too high group number in flatten");
|
||||
if(is_flat()) return *this;
|
||||
if(is_flat())
|
||||
{
|
||||
if(has_symmetry()) //get rid of formal symemtry
|
||||
{
|
||||
Tensor<T> r(*this);
|
||||
r.shape.copyonwrite();
|
||||
for(int g=0; g<r.shape.size(); ++g) r.shape[g].symmetry=0;
|
||||
return r;
|
||||
}
|
||||
else
|
||||
return *this;
|
||||
}
|
||||
if(group>=0) //single group
|
||||
{
|
||||
if(shape[group].number==1) return *this;
|
||||
if(shape[group].number==1)
|
||||
{
|
||||
if(shape[group].symmetry==0) return *this;
|
||||
else
|
||||
{
|
||||
Tensor<T> r(*this);
|
||||
r.shape[group].symmetry=0;
|
||||
return r;
|
||||
}
|
||||
}
|
||||
if(shape[group].symmetry==0)
|
||||
{
|
||||
Tensor<T> r(*this);
|
||||
@@ -816,7 +836,11 @@ if(group>=0) //single group
|
||||
if(group<0 && !is_compressed())
|
||||
{
|
||||
Tensor<T> r(*this);
|
||||
for(int g=0; g<shape.size(); ++g) if(shape[g].number>1) r.split_index_group(g);
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
{
|
||||
if(shape[g].number>1) r.split_index_group(g);
|
||||
}
|
||||
for(int g=0; g<r.shape.size(); ++g) r.shape[g].symmetry=0;
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -1285,7 +1309,7 @@ return r;
|
||||
|
||||
|
||||
//NOTE: Tucker of rank=2 is inherently inefficient - result is a diagonal tensor stored in full and 2 calls to SVD
|
||||
//we could avoid the second SVD, but the wasteful storage and erconstruction would remain
|
||||
//we could avoid the second SVD, but the wasteful storage and reconstruction would remain
|
||||
//
|
||||
template<typename T>
|
||||
NRVec<NRMat<T> > Tensor<T>::Tucker(typename LA_traits<T>::normtype thr, bool inverseorder)
|
||||
@@ -1314,9 +1338,9 @@ for(int i=0; i<r; ++i)
|
||||
NRMat<T> um;
|
||||
NRVec<indexgroup> ushape;
|
||||
{
|
||||
Tensor<T> u=unwind_index(I);
|
||||
ushape=u.shape; ushape.copyonwrite();
|
||||
um=u.matrix();
|
||||
Tensor<T> uu=unwind_index(I);
|
||||
ushape=uu.shape; //ushape.copyonwrite(); should not be needed
|
||||
um=uu.matrix();
|
||||
}
|
||||
int mini=um.nrows(); if(um.ncols()<mini) mini=um.ncols(); //compact SVD, expect descendingly sorted values
|
||||
NRMat<T> u(um.nrows(),mini),vt(mini,um.ncols());
|
||||
|
||||
Reference in New Issue
Block a user