bugfix unwind_index; Tucker works
This commit is contained in:
167
tensor.cc
167
tensor.cc
@@ -34,7 +34,7 @@ int r=0;
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
{
|
||||
const indexgroup *sh = &(* const_cast<const NRVec<indexgroup> *>(&shape))[i];
|
||||
if(sh->number==0) laerror("empty index group");
|
||||
if(sh->number<=0) laerror("empty index group"); //we do not support scalar as a trivial case
|
||||
r+=sh->number;
|
||||
}
|
||||
myrank=r;
|
||||
@@ -46,6 +46,7 @@ return r;
|
||||
template<typename T>
|
||||
LA_largeindex Tensor<T>::calcsize()
|
||||
{
|
||||
if(shape.size()==0) laerror("tensor must have rank at least 1");
|
||||
groupsizes.resize(shape.size());
|
||||
cumsizes.resize(shape.size());
|
||||
LA_largeindex s=1;
|
||||
@@ -330,11 +331,11 @@ calcsize();
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Tensor<T>::Tensor(const NRMat<T> &x)
|
||||
Tensor<T>::Tensor(const NRMat<T> &x, bool flat)
|
||||
: data(&x(0,0),x.nrows()*x.ncols())
|
||||
{
|
||||
myrank=2;
|
||||
if(x.nrows()==x.ncols())
|
||||
if(x.nrows()==x.ncols() && !flat)
|
||||
{
|
||||
shape.resize(1);
|
||||
shape[0].number=2;
|
||||
@@ -542,7 +543,9 @@ help_t<T>->data[target] = *v;
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
|
||||
{
|
||||
NRVec<indexgroup> newshape=shape.permuted(p);
|
||||
//std::cout <<"permute_index_groups permutation = "<<p<<std::endl;
|
||||
NRVec<indexgroup> newshape=shape.permuted(p,true);
|
||||
//std::cout <<"permute_index_groups newshape = "<<newshape<<std::endl;
|
||||
Tensor<T> r(newshape);
|
||||
|
||||
//prepare statics for the callback
|
||||
@@ -568,13 +571,8 @@ for(int i=0; i<I.size(); ++i)
|
||||
return J;
|
||||
}
|
||||
|
||||
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape)
|
||||
{
|
||||
int ii=0;
|
||||
for(int g=0; g<i.group; ++g) ii+= shape[g].number;
|
||||
ii += i.index;
|
||||
return ii;
|
||||
}
|
||||
int flatposition(const INDEX &i, const NRVec<indexgroup> &shape)
|
||||
{return flatposition(i.group,i.index,shape);}
|
||||
|
||||
int flatposition(int group, int index, const NRVec<indexgroup> &shape)
|
||||
{
|
||||
@@ -584,12 +582,26 @@ ii += index;
|
||||
return ii;
|
||||
}
|
||||
|
||||
INDEX indexposition(int flatindex, const NRVec<indexgroup> &shape)
|
||||
{
|
||||
INDEX I={0,0};
|
||||
if(flatindex<0) laerror("illegal index in indexposition");
|
||||
for(int g=0; g<shape.size(); ++g)
|
||||
{
|
||||
I.group=g;
|
||||
if(flatindex<shape[g].number) {I.index=flatindex; return I;}
|
||||
flatindex-=shape[g].number;
|
||||
}
|
||||
laerror("flatindex out of range");
|
||||
return I;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
static void unwind_callback(const SUPERINDEX &I, T *v)
|
||||
{
|
||||
FLATINDEX J = superindex2flat(I);
|
||||
FLATINDEX JP = J.permuted(*help_p,true);
|
||||
FLATINDEX JP = J.permuted(*help_p,false);
|
||||
//std::cout <<"TEST unwind_callback: from "<<JP<<" TO "<<J<<std::endl;
|
||||
*v = (*help_tt<T>)(JP); //rhs operator() generates the redundant elements for the unwinded lhs tensor
|
||||
}
|
||||
@@ -637,6 +649,8 @@ for(int i=0; i<shape.size(); ++i)
|
||||
else flatindex += shape[i].number;
|
||||
}
|
||||
|
||||
std::cout <<"unwind new shape = "<<newshape<<std::endl;
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(r.rank()!=rank()) laerror("internal error 2 in unwind_index");
|
||||
|
||||
@@ -656,6 +670,8 @@ if(!indexperm.is_valid())
|
||||
laerror("internal error 3 in unwind_index");
|
||||
}
|
||||
|
||||
std::cout <<"unwind permutation = "<<indexperm<<std::endl;
|
||||
|
||||
//loop recursively and do the unwinding
|
||||
help_tt<T> = this;
|
||||
help_p = &indexperm;
|
||||
@@ -782,7 +798,7 @@ return r;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A(nn,kk) * B^T(mm,kk)
|
||||
static void auxmatmult(int nn, int mm, int kk, T *r, const T *a, const T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A(nn,kk) * B^T(mm,kk)
|
||||
{
|
||||
for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
|
||||
{
|
||||
@@ -793,13 +809,13 @@ for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
|
||||
|
||||
|
||||
template<>
|
||||
void auxmatmult<double>(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta, bool conjugate)
|
||||
void auxmatmult<double>(int nn, int mm, int kk, double *r, const double *a, const double *b, double alpha, double beta, bool conjugate)
|
||||
{
|
||||
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, alpha, a, kk, b, kk, beta, r, mm);
|
||||
}
|
||||
|
||||
template<>
|
||||
void auxmatmult<std::complex<double> >(int nn, int mm, int kk, std::complex<double> *r, std::complex<double> *a, std::complex<double> *b, std::complex<double> alpha, std::complex<double> beta, bool conjugate)
|
||||
void auxmatmult<std::complex<double> >(int nn, int mm, int kk, std::complex<double> *r, const std::complex<double> *a, const std::complex<double> *b, std::complex<double> alpha, std::complex<double> beta, bool conjugate)
|
||||
{
|
||||
cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans), nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm);
|
||||
}
|
||||
@@ -823,9 +839,8 @@ if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number) laerror("wrong index numbe
|
||||
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
|
||||
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
|
||||
|
||||
Tensor<T> u = rhs1.unwind_index(group,index);
|
||||
if(conjugate1) u.conjugateme();
|
||||
Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
||||
const Tensor<T> u = conjugate1? (rhs1.unwind_index(group,index)).conjugate() : rhs1.unwind_index(group,index);
|
||||
const Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex);
|
||||
|
||||
|
||||
NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2);
|
||||
@@ -1077,51 +1092,99 @@ return r;
|
||||
|
||||
|
||||
template<typename T>
|
||||
NRVec<NRMat<T> > Tensor<T>::Tucker(typename LA_traits<T>::normtype thr)
|
||||
NRVec<NRMat<T> > Tensor<T>::Tucker(typename LA_traits<T>::normtype thr, bool inverseorder)
|
||||
{
|
||||
int r=rank();
|
||||
NRVec<NRMat<T> > ret(r);
|
||||
if(r<2) return ret;
|
||||
if(r<1) laerror("illegal rank in Tucker");
|
||||
copyonwrite();
|
||||
|
||||
int rr=0;
|
||||
for(int i=0; i<shape.size(); ++i)
|
||||
for(int j=0; j<shape[i].number; ++j) //loop over all indices
|
||||
if(r==1) //create an analogous output for the trivial case
|
||||
{
|
||||
ret[0]=NRMat<T>(data,data.size(),1);
|
||||
shape[0].range=1;
|
||||
data.resize(calcsize());
|
||||
calcrank();
|
||||
data[0]=1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
//loop over all indices; relies on the fact tha unwinding does not change order of remaining indices
|
||||
for(int i=0; i<r; ++i)
|
||||
{
|
||||
INDEX I=indexposition(i,shape);
|
||||
NRMat<T> um;
|
||||
NRVec<indexgroup> ushape;
|
||||
{
|
||||
Tensor<T> u=unwind_index(I);
|
||||
ushape=u.shape; ushape.copyonwrite();
|
||||
um=u.matrix();
|
||||
}
|
||||
int mini=um.nrows(); if(um.ncols()<mini) mini=um.ncols(); //compact SVD, expect descendingly sorted values
|
||||
NRMat<T> u(um.nrows(),mini),vt(mini,um.ncols());
|
||||
NRVec<typename LA_traits<T>::normtype> w(mini);
|
||||
singular_decomposition(um,&u,w,&vt,0);
|
||||
um.resize(0,0); //deallocate
|
||||
int preserve=mini;
|
||||
for(int k=0; k<mini; ++k) if(w[k]<thr) {preserve=k; break;}
|
||||
if(preserve==0) laerror("singular tensor in Tucker decomposition");
|
||||
NRMat<T> umnew;
|
||||
if(preserve<mini)
|
||||
{
|
||||
NRMat<T> um;
|
||||
NRVec<indexgroup> ushape;
|
||||
{
|
||||
Tensor<T> u=unwind_index(i,j);
|
||||
ushape=u.shape;
|
||||
um=u.matrix();
|
||||
}
|
||||
int mini=um.nrows(); if(um.ncols()<mini) mini=um.ncols(); //compact SVD, expect descendingly sorted values
|
||||
NRMat<T> u(um.nrows(),mini),vt(mini,um.ncols());
|
||||
NRVec<typename LA_traits<T>::normtype> w(mini);
|
||||
singular_decomposition(um,&u,w,&vt,0);
|
||||
um.resize(0,0); //deallocate
|
||||
int preserve=mini;
|
||||
for(int k=0; k<mini; ++k) if(w[k]<thr) {preserve=k; break;}
|
||||
if(preserve==0) laerror("singular tensor in Tucker decomposition");
|
||||
NRMat<T> umnew;
|
||||
if(preserve<mini)
|
||||
{
|
||||
vt=vt.submatrix(0,preserve-1,0,um.ncols()-1);
|
||||
w=w.subvector(0,preserve-1);
|
||||
umnew=u.submatrix(0,um.nrows()-1,0,preserve-1);
|
||||
}
|
||||
else umnew=u;
|
||||
ret[rr++]=vt.transpose(true);
|
||||
umnew.diagmultr(w);
|
||||
//rebuild tensor of the preserved shape from matrix
|
||||
ushape[0].range=preserve;
|
||||
NRVec<T> newdata(umnew);
|
||||
*this = Tensor(ushape,newdata);
|
||||
vt=vt.submatrix(0,preserve-1,0,um.ncols()-1);
|
||||
w=w.subvector(0,preserve-1);
|
||||
umnew=u.submatrix(0,um.nrows()-1,0,preserve-1);
|
||||
}
|
||||
else umnew=u;
|
||||
ret[(inverseorder? r-i-1 : i)]=vt.transpose(true);
|
||||
umnew.diagmultr(w);
|
||||
//rebuild tensor of the preserved shape from matrix
|
||||
ushape[0].range=preserve;
|
||||
{
|
||||
NRVec<T> newdata(umnew);
|
||||
umnew.resize(0,0);//deallocate
|
||||
*this = Tensor(ushape,newdata);
|
||||
}
|
||||
}
|
||||
if(!is_flat()) laerror("this should not happen");
|
||||
if(!inverseorder)
|
||||
{
|
||||
NRPerm<int> p(r);
|
||||
for(int i=1; i<=r; ++i) p[r-i+1]=i;
|
||||
*this = permute_index_groups(p);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::inverseTucker(const NRVec<NRMat<T> > &x, bool inverseorder) const
|
||||
{
|
||||
if(rank()!=x.size()) laerror("input of inverseTucker does not match rank");
|
||||
Tensor<T> tmp(*this);
|
||||
Tensor<T> r;
|
||||
if(!is_flat()) laerror("inverseTucker only for flat tensors as produced by Tucker");
|
||||
for(int i=0; i<rank(); ++i)
|
||||
{
|
||||
Tensor<T> mat(x[i],true);
|
||||
r= tmp.contraction(i,0,mat,0,0,(T)1,false,false);
|
||||
if(i<rank()-1)
|
||||
{
|
||||
tmp=r;
|
||||
r.deallocate();
|
||||
}
|
||||
}
|
||||
if(!inverseorder)
|
||||
{
|
||||
NRPerm<int> p(r.rank());
|
||||
for(int i=1; i<=r.rank(); ++i) p[r.rank()-i+1]=i;
|
||||
return r.permute_index_groups(p);
|
||||
}
|
||||
else
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user