tensor: implemented index names
tensor.cc
@@ -72,6 +72,7 @@ for(int i=0; i<shape.size(); ++i)
break;
}
}
if(s==0) laerror("empty tensor - perhaps antisymmetric in dim<rank");
return s;
}

@@ -660,6 +661,24 @@ help_t<T>->data[target] = *v;
}


+//permutation of individual indices from permutation of index groups
+NRPerm<int> group2flat_perm(const NRVec<indexgroup> &shape, const NRPerm<int> &p)
+{
+int rank=0;
+for(int i=0; i<shape.size(); ++i) rank+=shape[i].number;
+NRPerm<int> q(rank);
+int ii=1;
+for(int i=0; i<shape.size(); ++i)
+    {
+    int selgroup=p[i+1]-1;
+    for(int j=0; j<shape[selgroup].number; ++j)
+        {
+        q[ii++] = 1+flatposition(selgroup,j,shape);
+        }
+    }
+return q;
+}


template<typename T>
Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
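To make the new group2flat_perm mapping concrete, here is a small standalone illustration of the same idea: walk the index groups in their permuted order and emit the original flat position of every index in each group. It uses 0-based indexing and plain std::vector; group2flat, its parameters, and the example shape are hypothetical stand-ins, not the library's NRPerm/flatposition API.

#include <cassert>
#include <cstddef>
#include <vector>

// 0-based analogue of group2flat_perm: sizes[i] = number of indices in group i,
// gperm[i] = the original group that ends up at position i of the new shape.
std::vector<int> group2flat(const std::vector<int>& sizes, const std::vector<int>& gperm)
{
    // offset[g] = flat position of the first index of original group g
    std::vector<int> offset(sizes.size(), 0);
    for (std::size_t g = 1; g < sizes.size(); ++g) offset[g] = offset[g-1] + sizes[g-1];

    std::vector<int> q;
    for (int g : gperm)                        // walk groups in their new order
        for (int j = 0; j < sizes[g]; ++j)     // emit the old flat positions of that group
            q.push_back(offset[g] + j);
    return q;
}

int main()
{
    // two groups of 2 and 3 indices, swapped: the new flat order lists group 1's
    // three indices (old positions 2,3,4) before group 0's two (old positions 0,1)
    std::vector<int> q = group2flat({2, 3}, {1, 0});
    assert((q == std::vector<int>{2, 3, 4, 0, 1}));
    return 0;
}

The library version returns a 1-based NRPerm so that it can be fed directly to names.permuted(), as done in permute_index_groups below.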
@@ -668,6 +687,11 @@ Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
NRVec<indexgroup> newshape=shape.permuted(p,true);
//std::cout <<"permute_index_groups newshape = "<<newshape<<std::endl;
Tensor<T> r(newshape);
+if(is_named())
+    {
+    NRPerm<int> q=group2flat_perm(shape,p);
+    r.names = names.permuted(q,true);
+    }

//prepare statics for the callback
help_p = &p;
@@ -773,6 +797,7 @@ for(int i=0; i<shape.size(); ++i)
else flatindex += shape[i].number;
}


//std::cout <<"unwind new shape = "<<newshape<<std::endl;

Tensor<T> r(newshape);
@@ -790,12 +815,18 @@ for(int i=2; i<=rank(); ++i)
}
if(!indexperm.is_valid())
    {
-    std::cout << "indexperm = "<<indexperm<<std::endl;
+    //std::cout << "indexperm = "<<indexperm<<std::endl;
    laerror("internal error 3 in unwind_index");
    }

//std::cout <<"unwind permutation = "<<indexperm<<std::endl;

+if(is_named())
+    {
+    r.names = names.permuted(indexperm,true);
+    //std::cout <<"unwind new names = "<<r.names<<std::endl;
+    }
+
//loop recursively and do the unwinding
help_tt<T> = this;
help_p = &indexperm;
@@ -892,10 +923,12 @@ for(int g=0; g<shape.size(); ++g)
    }
}

-std::cout <<"Flatten new shape = "<<newshape<<std::endl;
+//std::cout <<"Flatten new shape = "<<newshape<<std::endl;


//decompress the tensor data
Tensor<T> r(newshape);
+r.names=names;
help_tt<T> = this;
r.loopover(flatten_callback);
return r;
@@ -908,7 +941,7 @@ template<typename T>
Tensor<T> Tensor<T>::unwind_indices(const INDEXLIST &il) const
{
if(il.size()==0) return *this;
-if(il.size()==1) return unwind_index(il[0].group,il[0].index);
+if(il.size()==1) return unwind_index(il[0]);

for(int i=0; i<il.size(); ++i)
    {
@@ -1017,6 +1050,12 @@ if(!indexperm.is_valid())
laerror("internal error 3 in unwind_indices");
}

+if(is_named())
+    {
+    r.names = names.permuted(indexperm,true);
+    //std::cout <<"unwind new names = "<<r.names<<std::endl;
+    }
+
//loop recursively and do the unwinding
help_tt<T> = this;
help_p = &indexperm;
@@ -1078,15 +1117,28 @@ int ii=0;
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one

+NRVec<INDEXNAME> newnames;
+if(u.is_named() && rhsu.is_named())
+    {
+    if(u.names[0]!=rhsu.names[0]) laerror("contraction index name mismatch in addcontraction");
+    newnames.resize(u.names.size()+rhsu.names.size()-2);
+    int ii=0;
+    for(int i=1; i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
+    for(int i=1; i<u.names.size(); ++i) newnames[ii++] = u.names[i];
+    }
+
if(doresize)
    {
    if(beta!= (T)0) laerror("resize in addcontraction requires beta=0");
    resize(newshape);
+    names=newnames;
    }
else
    {
    if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
+    if(is_named() && names!=newnames) laerror("remaining index names mismatch in addcontraction");
    }

int nn,mm,kk;
kk=u.groupsizes[0];
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addcontraction");
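The name bookkeeping added to addcontraction mirrors the existing shape bookkeeping: after unwinding, the contracted index sits at position 0 of both operands, its names must agree, and the surviving names of the rhs come first, followed by this tensor's surviving names. A standalone sketch of that assembly, using std::string/std::vector as illustrative stand-ins for the library's NRVec<INDEXNAME> (contracted_names and the example names are hypothetical):

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Build the result's index names from the two unwound operands' name lists.
std::vector<std::string> contracted_names(const std::vector<std::string>& lhs,
                                          const std::vector<std::string>& rhs)
{
    if (lhs.at(0) != rhs.at(0))
        throw std::runtime_error("contraction index name mismatch");
    std::vector<std::string> out;
    out.insert(out.end(), rhs.begin() + 1, rhs.end());  // less significant: rhs survivors
    out.insert(out.end(), lhs.begin() + 1, lhs.end());  // more significant: lhs survivors
    return out;
}

int main()
{
    // lhs indices {k, a, b}, rhs indices {k, c}; contracting over k leaves {c, a, b}
    for (const std::string& n : contracted_names({"k", "a", "b"}, {"k", "c"}))
        std::cout << n << ' ';
    std::cout << '\n';
    return 0;
}

The multi-index version in addcontractions below generalizes this by matching the first il1.size() names pairwise and keeping the remainder in the same rhs-then-lhs order.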
@@ -1097,6 +1149,7 @@ auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate);
}



template<typename T>
void Tensor<T>::addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate2)
{
@@ -1125,15 +1178,29 @@ int ii=0;
for(int i=il1.size(); i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
for(int i=il1.size(); i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one

+NRVec<INDEXNAME> newnames;
+if(u.is_named() && rhsu.is_named())
+    {
+    for(int i=0; i<il1.size(); ++i) if(u.names[i]!=rhsu.names[i]) laerror("contraction index name mismatch in addcontractions");
+    newnames.resize(u.names.size()+rhsu.names.size()-2*il1.size());
+    int ii=0;
+    for(int i=il1.size(); i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
+    for(int i=il1.size(); i<u.names.size(); ++i) newnames[ii++] = u.names[i];
+    }
+

if(doresize)
    {
    if(beta!= (T)0) laerror("resize in addcontractions requires beta=0");
    resize(newshape);
+    names=newnames;
    }
else
    {
    if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
+    if(is_named() && names!=newnames) laerror("remaining index names mismatch in addcontractions");
    }

int nn,mm,kk;
kk=1;
int kk2=1;
@@ -1318,8 +1385,8 @@ for(int g=0; g<shape.size(); ++g)
    p[gg++] = 1+g;
if(gg!=shape.size() || !p.is_valid()) laerror("internal error in merge_index_groups");

-Tensor<T> r = permute_index_groups(p);
-r.merge_adjacent_index_groups(0,groups.size()-1);
+Tensor<T> r = permute_index_groups(p); //takes care of names permutation too
+r.merge_adjacent_index_groups(0,groups.size()-1); //flat names invariant
return r;
}

@@ -1437,7 +1504,7 @@ int Tensor<T>::findflatindex(const INDEXNAME nam) const
if(!is_named()) laerror("tensor indices were not named");
for(int i=0; i<names.size(); ++i)
    {
-    if(!strncmp(nam.name,names[i].name,N_INDEXNAME)) return i;
+    if(nam==names[i]) return i;
    }
return -1;
}
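The comparison change in findflatindex suggests INDEXNAME now carries an equality operator over its fixed-length name buffer, replacing the explicit strncmp. A minimal sketch of such a type is given below; it is an illustrative assumption, not the library's actual INDEXNAME definition, and the buffer length used here is made up (the real one is the library's N_INDEXNAME constant).

#include <cstring>

// Hypothetical fixed-length index-name type: equality hides the strncmp
// call that findflatindex previously spelled out inline.
const int NAME_LEN = 8;   // assumed length, stands in for N_INDEXNAME

struct IndexName {
    char name[NAME_LEN];
    bool operator==(const IndexName &rhs) const
    {
        return std::strncmp(name, rhs.name, NAME_LEN) == 0;
    }
    bool operator!=(const IndexName &rhs) const { return !(*this == rhs); }
};

With the operator in place, name lookups and the name-mismatch checks in the contraction routines can compare INDEXNAME values directly.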
@@ -1450,6 +1517,14 @@ if(n<0) laerror("index with this name was not found");
return indexposition(n,shape);
}

+template<typename T>
+NRVec<INDEX> Tensor<T>::findindexlist(const NRVec<INDEXNAME> &names) const
+{
+int n=names.size();
+NRVec<INDEX> ind(n);
+for(int i=0; i<n; ++i) ind[i] = findindex(names[i]);
+return ind;
+}
