Tucked tested on compressed tensors, flattening implemented
This commit is contained in:
91
tensor.cc
91
tensor.cc
@@ -516,7 +516,7 @@ template<typename T>
|
||||
void Tensor<T>::grouploopover(void (*callback)(const GROUPINDEX &, T *))
|
||||
{
|
||||
GROUPINDEX I(shape.size());
|
||||
T *pp=&data[0];
|
||||
T *pp= &data[0];
|
||||
loopovergroups(*this,shape.size()-1,&pp,I,callback);
|
||||
}
|
||||
|
||||
@@ -649,7 +649,7 @@ for(int i=0; i<shape.size(); ++i)
|
||||
else flatindex += shape[i].number;
|
||||
}
|
||||
|
||||
std::cout <<"unwind new shape = "<<newshape<<std::endl;
|
||||
//std::cout <<"unwind new shape = "<<newshape<<std::endl;
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
|
||||
@@ -670,7 +670,7 @@ if(!indexperm.is_valid())
|
||||
laerror("internal error 3 in unwind_index");
|
||||
}
|
||||
|
||||
std::cout <<"unwind permutation = "<<indexperm<<std::endl;
|
||||
//std::cout <<"unwind permutation = "<<indexperm<<std::endl;
|
||||
|
||||
//loop recursively and do the unwinding
|
||||
help_tt<T> = this;
|
||||
@@ -680,6 +680,79 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void flatten_callback(const SUPERINDEX &I, T *v)
|
||||
{
|
||||
FLATINDEX J = superindex2flat(I);
|
||||
//std::cout <<"TEST flatten_callback: from "<<JP<<" TO "<<J<<std::endl;
|
||||
*v = (*help_tt<T>)(J); //rhs operator() generates the redundant elements for the unwinded lhs tensor
|
||||
}
|
||||
//
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::flatten(int group) const
|
||||
{
|
||||
if(group>=shape.size()) laerror("too high group number in flatten");
|
||||
if(is_flat()) return *this;
|
||||
if(group>=0) //single group
|
||||
{
|
||||
if(shape[group].number==1) return *this;
|
||||
if(shape[group].symmetry==0)
|
||||
{
|
||||
Tensor<T> r(*this);
|
||||
r.split_index_group(group);
|
||||
return r;
|
||||
}
|
||||
}
|
||||
if(group<0 && !is_compressed())
|
||||
{
|
||||
Tensor<T> r(*this);
|
||||
for(int g=0; g<shape.size(); ++g) if(shape[g].number>1) r.split_index_group(g);
|
||||
return r;
|
||||
}
|
||||
|
||||
//general case
|
||||
int newsize;
|
||||
if(group<0) newsize=rank();
|
||||
else newsize=shape.size()+shape[group].number-1;
|
||||
|
||||
//build new shape
|
||||
NRVec<indexgroup> newshape(newsize);
|
||||
int gg=0;
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
{
|
||||
if((group<0 ||g==group) && shape[g].number>1) //flatten this group
|
||||
{
|
||||
for(int i=0; i<shape[g].number; ++i)
|
||||
{
|
||||
newshape[gg].symmetry=0;
|
||||
newshape[gg].number=1;
|
||||
newshape[gg].range=shape[g].range;
|
||||
#ifndef LA_TENSOR_ZERO_OFFSET
|
||||
newshape[gg].offset = shape[g].offset;
|
||||
#endif
|
||||
gg++;
|
||||
}
|
||||
}
|
||||
else //preserve this group
|
||||
{
|
||||
newshape[gg++] = shape[g];
|
||||
}
|
||||
}
|
||||
|
||||
std::cout <<"Flatten new shape = "<<newshape<<std::endl;
|
||||
|
||||
//decompress the tensor data
|
||||
Tensor<T> r(newshape);
|
||||
help_tt<T> = this;
|
||||
r.loopover(flatten_callback);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::unwind_indices(const INDEXLIST &il) const
|
||||
{
|
||||
@@ -1012,7 +1085,7 @@ void Tensor<T>::split_index_group(int group)
|
||||
{
|
||||
if(group<0||group >= shape.size()) laerror("illegal index group number");
|
||||
if(shape[group].number==1) return; //nothing to split
|
||||
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted");
|
||||
if(shape[group].symmetry!=0) laerror("only non-symmetric index group can be splitted, use flatten instead");
|
||||
|
||||
NRVec<indexgroup> newshape(shape.size()+shape[group].number-1);
|
||||
int gg=0;
|
||||
@@ -1133,16 +1206,19 @@ for(int i=0; i<r; ++i)
|
||||
//std::cout << "resulting U "<<u<<std::endl;
|
||||
//std::cout << "resulting W "<<w<<std::endl;
|
||||
//std::cout << "resulting VT "<<vt<<std::endl;
|
||||
int umnr=um.nrows();
|
||||
int umnc=um.ncols();
|
||||
um.resize(0,0); //deallocate
|
||||
int preserve=mini;
|
||||
for(int k=0; k<mini; ++k) if(w[k]<thr) {preserve=k; break;}
|
||||
if(preserve==0) laerror("singular tensor in Tucker decomposition");
|
||||
NRMat<T> umnew;
|
||||
//std::cout <<"TEST "<<i<<" mini preserve "<<mini<<" "<<preserve<<std::endl;
|
||||
if(preserve<mini)
|
||||
{
|
||||
vt=vt.submatrix(0,preserve-1,0,um.ncols()-1);
|
||||
vt=vt.submatrix(0,preserve-1,0,umnc-1);
|
||||
w=w.subvector(0,preserve-1);
|
||||
umnew=u.submatrix(0,um.nrows()-1,0,preserve-1);
|
||||
umnew=u.submatrix(0,umnr-1,0,preserve-1);
|
||||
}
|
||||
else umnew=u;
|
||||
ret[(inverseorder? r-i-1 : i)]=vt.transpose(true);
|
||||
@@ -1197,6 +1273,9 @@ else
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template class Tensor<double>;
|
||||
template class Tensor<std::complex<double> >;
|
||||
template std::ostream & operator<<(std::ostream &s, const Tensor<double> &x);
|
||||
|
||||
Reference in New Issue
Block a user