const version of loopover and grouploopoevr in tensor class

This commit is contained in:
Jiri Pittner 2025-10-23 16:15:42 +02:00
parent cd09d93c27
commit 3baced9adb
2 changed files with 94 additions and 7 deletions

View File

@ -376,6 +376,7 @@ calcsize();
}
template<typename T>
void loopingroups(Tensor<T> &t, int ngroup, int igroup, T **p, SUPERINDEX &I, void (*callback)(const SUPERINDEX &, T *))
{
@ -441,14 +442,78 @@ loopingroups(*this,ss,sh->number-1,&pp,I,callback);
}
template<typename T>
void constloopingroups(const Tensor<T> &t, int ngroup, int igroup, const T **p, SUPERINDEX &I, void (*callback)(const SUPERINDEX &, const T *))
{
LA_index istart,iend;
const indexgroup *sh = &t.shape[ngroup];
switch(sh->symmetry)
{
case 0:
istart= sh->offset;
iend= sh->offset+sh->range-1;
break;
case 1:
istart= sh->offset;
if(igroup==sh->number-1) iend= sh->offset+sh->range-1;
else iend = I[ngroup][igroup+1];
break;
case -1:
istart= sh->offset + igroup;
if(igroup==sh->number-1) iend= sh->offset+sh->range-1;
else iend = I[ngroup][igroup+1]-1;
break;
}
for(LA_index i = istart; i<=iend; ++i)
{
I[ngroup][igroup]=i;
if(ngroup==0 && igroup==0)
{
int sign;
//std::cout <<"TEST "<<t.index(&sign,I)<<" ";
(*callback)(I,(*p)++);
}
else
{
int newigroup= igroup-1;
int newngroup=ngroup;
if(newigroup<0)
{
--newngroup;
const indexgroup *sh2 = &(* const_cast<const NRVec<indexgroup> *>(&t.shape))[newngroup];
newigroup=sh2->number-1;
}
constloopingroups(t,newngroup,newigroup,p,I,callback);
}
}
}
template<typename T>
void Tensor<T>::constloopover(void (*callback)(const SUPERINDEX &, const T *)) const
{
SUPERINDEX I(shape.size());
for(int i=0; i<I.size(); ++i)
{
const indexgroup *sh = &shape[i];
I[i].resize(sh->number);
I[i] = sh->offset;
}
const T *pp=&data[0];
int ss=shape.size()-1;
const indexgroup *sh = &shape[ss];
constloopingroups(*this,ss,sh->number-1,&pp,I,callback);
}
static std::ostream *sout;
template<typename T>
static void outputcallback(const SUPERINDEX &I, T *v)
static void outputcallback(const SUPERINDEX &I, const T *v)
{
//print indices flat
for(int i=0; i<I.size(); ++i)
for(int j=0; j<I[i].size(); ++j) *sout << I[i][j]<<" ";
//*sout<<" "<< " "<<(void *)v<<" "<< *v<<std::endl;
*sout<<" "<< *v<<std::endl;
}
@ -480,7 +545,7 @@ std::ostream & operator<<(std::ostream &s, const Tensor<T> &x)
{
s<<x.shape;
sout= &s;
const_cast<Tensor<T> *>(&x)->loopover(&outputcallback<T>);
x.constloopover(&outputcallback<T>);
return s;
}
@ -521,6 +586,28 @@ loopovergroups(*this,shape.size()-1,&pp,I,callback);
}
template<typename T>
void constloopovergroups(const Tensor<T> &t, int ngroup, const T **p, GROUPINDEX &I, void (*callback)(const GROUPINDEX &, const T *))
{
for(LA_largeindex i = 0; i<t.groupsizes[ngroup]; ++i)
{
I[ngroup]=i;
if(ngroup==0) (*callback)(I,(*p)++);
else constloopovergroups(t,ngroup-1,p,I,callback);
}
}
template<typename T>
void Tensor<T>::constgrouploopover(void (*callback)(const GROUPINDEX &, const T *)) const
{
GROUPINDEX I(shape.size());
const T *pp= &data[0];
constloopovergroups(*this,shape.size()-1,&pp,I,callback);
}
const NRPerm<int> *help_p;
template<typename T>
Tensor<T> *help_t;
@ -528,7 +615,7 @@ template<typename T>
const Tensor<T> *help_tt;
template<typename T>
static void permutecallback(const GROUPINDEX &I, T *v)
static void permutecallback(const GROUPINDEX &I, const T *v)
{
LA_largeindex target=0;
for(int i=0; i< help_t<T>->shape.size(); ++i)
@ -553,7 +640,7 @@ help_p = &p;
help_t<T> = &r;
//now rearrange the data
const_cast<Tensor<T> *>(this)->grouploopover(permutecallback<T>);
const_cast<Tensor<T> *>(this)->constgrouploopover(permutecallback<T>);
return r;
}

View File

@ -49,8 +49,6 @@
//@@@conversions to/from fourindex, optional negarive rande for beta spin handling
//@@@ optional distinguish covariant and contravariant check in contraction
//
//maybe const loopover and grouploopover to avoid problems with shallowly copied tensors
//
//@@@?general permutation of individual indices - check the indices in sym groups remain adjacent, calculate result's shape, loopover the result and permute using unwind_callback
//
//
@ -224,7 +222,9 @@ public:
inline void randomize(const typename LA_traits<T>::normtype &x) {data.randomize(x);};
void loopover(void (*callback)(const SUPERINDEX &, T *)); //loop over all elements
void constloopover(void (*callback)(const SUPERINDEX &, const T *)) const; //loop over all elements
void grouploopover(void (*callback)(const GROUPINDEX &, T *)); //loop over all elements disregarding the internal structure of index groups
void constgrouploopover(void (*callback)(const GROUPINDEX &, const T *)) const; //loop over all elements disregarding the internal structure of index groups
Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one (the leftmost one)