tensor: implemented index names

This commit is contained in:
2025-11-05 16:09:29 +01:00
parent 7941b7a7e2
commit 5e4c4dddea
3 changed files with 156 additions and 19 deletions

View File

@@ -72,6 +72,7 @@ for(int i=0; i<shape.size(); ++i)
break;
}
}
if(s==0) laerror("empty tensor - perhaps antisymmetric in dim<rank");
return s;
}
@@ -660,6 +661,24 @@ help_t<T>->data[target] = *v;
}
//permutation of individual indices from permutation of index groups
NRPerm<int> group2flat_perm(const NRVec<indexgroup> &shape, const NRPerm<int> &p)
{
int rank=0;
for(int i=0; i<shape.size(); ++i) rank+=shape[i].number;
NRPerm<int> q(rank);
int ii=1;
for(int i=0; i<shape.size(); ++i)
{
int selgroup=p[i+1]-1;
for(int j=0; j<shape[selgroup].number; ++j)
{
q[ii++] = 1+flatposition(selgroup,j,shape);
}
}
return q;
}
template<typename T>
Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
@@ -668,6 +687,11 @@ Tensor<T> Tensor<T>::permute_index_groups(const NRPerm<int> &p) const
NRVec<indexgroup> newshape=shape.permuted(p,true);
//std::cout <<"permute_index_groups newshape = "<<newshape<<std::endl;
Tensor<T> r(newshape);
if(is_named())
{
NRPerm<int> q=group2flat_perm(shape,p);
r.names = names.permuted(q,true);
}
//prepare statics for the callback
help_p = &p;
@@ -773,6 +797,7 @@ for(int i=0; i<shape.size(); ++i)
else flatindex += shape[i].number;
}
//std::cout <<"unwind new shape = "<<newshape<<std::endl;
Tensor<T> r(newshape);
@@ -790,12 +815,18 @@ for(int i=2; i<=rank(); ++i)
}
if(!indexperm.is_valid())
{
std::cout << "indexperm = "<<indexperm<<std::endl;
//std::cout << "indexperm = "<<indexperm<<std::endl;
laerror("internal error 3 in unwind_index");
}
//std::cout <<"unwind permutation = "<<indexperm<<std::endl;
if(is_named())
{
r.names = names.permuted(indexperm,true);
//std::cout <<"unwind new names = "<<r.names<<std::endl;
}
//loop recursively and do the unwinding
help_tt<T> = this;
help_p = &indexperm;
@@ -892,10 +923,12 @@ for(int g=0; g<shape.size(); ++g)
}
}
std::cout <<"Flatten new shape = "<<newshape<<std::endl;
//std::cout <<"Flatten new shape = "<<newshape<<std::endl;
//decompress the tensor data
Tensor<T> r(newshape);
r.names=names;
help_tt<T> = this;
r.loopover(flatten_callback);
return r;
@@ -908,7 +941,7 @@ template<typename T>
Tensor<T> Tensor<T>::unwind_indices(const INDEXLIST &il) const
{
if(il.size()==0) return *this;
if(il.size()==1) return unwind_index(il[0].group,il[0].index);
if(il.size()==1) return unwind_index(il[0]);
for(int i=0; i<il.size(); ++i)
{
@@ -1017,6 +1050,12 @@ if(!indexperm.is_valid())
laerror("internal error 3 in unwind_indices");
}
if(is_named())
{
r.names = names.permuted(indexperm,true);
//std::cout <<"unwind new names = "<<r.names<<std::endl;
}
//loop recursively and do the unwinding
help_tt<T> = this;
help_p = &indexperm;
@@ -1078,15 +1117,28 @@ int ii=0;
for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
NRVec<INDEXNAME> newnames;
if(u.is_named() && rhsu.is_named())
{
if(u.names[0]!=rhsu.names[0]) laerror("contraction index name mismatch in addcontraction");
newnames.resize(u.names.size()+rhsu.names.size()-2);
int ii=0;
for(int i=1; i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
for(int i=1; i<u.names.size(); ++i) newnames[ii++] = u.names[i];
}
if(doresize)
{
if(beta!= (T)0) laerror("resize in addcontraction requires beta=0");
resize(newshape);
names=newnames;
}
else
{
if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
if(is_named() && names!=newnames) laerror("remaining index names mismatch in addcontraction");
}
int nn,mm,kk;
kk=u.groupsizes[0];
if(kk!=rhsu.groupsizes[0]) laerror("internal error in addcontraction");
@@ -1097,6 +1149,7 @@ auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate);
}
template<typename T>
void Tensor<T>::addcontractions(const Tensor &rhs1, const INDEXLIST &il1, const Tensor &rhs2, const INDEXLIST &il2, T alpha, T beta, bool doresize, bool conjugate1, bool conjugate2)
{
@@ -1125,15 +1178,29 @@ int ii=0;
for(int i=il1.size(); i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i];
for(int i=il1.size(); i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
NRVec<INDEXNAME> newnames;
if(u.is_named() && rhsu.is_named())
{
for(int i=0; i<il1.size(); ++i) if(u.names[i]!=rhsu.names[i]) laerror("contraction index name mismatch in addcontractions");
newnames.resize(u.names.size()+rhsu.names.size()-2*il1.size());
int ii=0;
for(int i=il1.size(); i<rhsu.names.size(); ++i) newnames[ii++] = rhsu.names[i];
for(int i=il1.size(); i<u.names.size(); ++i) newnames[ii++] = u.names[i];
}
if(doresize)
{
if(beta!= (T)0) laerror("resize in addcontractions requires beta=0");
resize(newshape);
names=newnames;
}
else
{
if(shape!=newshape) laerror("tensor shape mismatch in addcontraction");
if(is_named() && names!=newnames) laerror("remaining index names mismatch in addcontractions");
}
int nn,mm,kk;
kk=1;
int kk2=1;
@@ -1318,8 +1385,8 @@ for(int g=0; g<shape.size(); ++g)
p[gg++] = 1+g;
if(gg!=shape.size() || !p.is_valid()) laerror("internal error in merge_index_groups");
Tensor<T> r = permute_index_groups(p);
r.merge_adjacent_index_groups(0,groups.size()-1);
Tensor<T> r = permute_index_groups(p); //takes care of names permutation too
r.merge_adjacent_index_groups(0,groups.size()-1); //flat names invariant
return r;
}
@@ -1437,7 +1504,7 @@ int Tensor<T>::findflatindex(const INDEXNAME nam) const
if(!is_named()) laerror("tensor indices were not named");
for(int i=0; i<names.size(); ++i)
{
if(!strncmp(nam.name,names[i].name,N_INDEXNAME)) return i;
if(nam==names[i]) return i;
}
return -1;
}
@@ -1450,6 +1517,14 @@ if(n<0) laerror("index with this name was not found");
return indexposition(n,shape);
}
template<typename T>
NRVec<INDEX> Tensor<T>::findindexlist(const NRVec<INDEXNAME> &names) const
{
int n=names.size();
NRVec<INDEX> ind(n);
for(int i=0; i<n; ++i) ind[i] = findindex(names[i]);
return ind;
}