tensor: support for complex (anti)hermitian tensors

This commit is contained in:
2025-11-18 17:30:58 +01:00
parent 417a7d1d1a
commit 20a61e2fb9
5 changed files with 101 additions and 36 deletions

View File

@@ -62,6 +62,8 @@ for(int i=0; i<shape.size(); ++i)
case 0:
s *= groupsizes[i] = longpow(sh->range,sh->number);
break;
case 2:
case -2:
case 1:
s *= groupsizes[i] = simplicial(sh->number,sh->range);
break;
@@ -104,18 +106,17 @@ switch(I.size()) //a few special cases for efficiency
break;
case 2:
{
*sign=1;
if(g.symmetry==0) return (I[1]-g.offset)*g.range+I[0]-g.offset;
if(g.symmetry==0) {*sign=1; return (I[1]-g.offset)*g.range+I[0]-g.offset;};
LA_index i0,i1;
if(I[0]>I[1]) {i1=I[0]; i0=I[1]; if(g.symmetry<0) *sign = -1;} else {i1=I[1]; i0=I[0];}
if(I[0]>I[1]) {i1=I[0]; i0=I[1]; *sign=g.symmetry;} else {i1=I[1]; i0=I[0]; *sign=1;}
i0 -= g.offset;
i1 -= g.offset;
if(g.symmetry<0)
if(g.symmetry == -1) //antisymmetric
{
if(i0==i1) {*sign=0; return -1;}
return i1*(i1-1)/2+i0;
}
else
else //symmetric, hermitian, antihermitian
{
return i1*(i1+1)/2+i0;
}
@@ -124,10 +125,9 @@ switch(I.size()) //a few special cases for efficiency
default: //general case
{
*sign=1;
if(g.symmetry==0) //rectangular case
{
*sign=1;
LA_largeindex r=0;
for(int i=I.size()-1; i>=0; --i)
{
@@ -143,8 +143,8 @@ switch(I.size()) //a few special cases for efficiency
II.copyonwrite();
if(g.offset!=0) II -= g.offset;
int parity=netsort(II.size(),&II[0]);
if(g.symmetry<0 && (parity&1)) *sign= -1;
if(g.symmetry<0) //antisymmetric
*sign= (parity&1) ? g.symmetry : 1;
if(g.symmetry == -1) //antisymmetric - do not store zero diagonal
{
for(int i=0; i<I.size()-1; ++i)
if(II[i]==II[i+1])
@@ -154,7 +154,7 @@ switch(I.size()) //a few special cases for efficiency
for(int i=0; i<II.size(); ++i) r += simplicial(i+1,II[i]-i);
return r;
}
else //symmetric
else //symmetric, hermitian, antihermitian
{
LA_largeindex r=0;
for(int i=0; i<II.size(); ++i) r += simplicial(i+1,II[i]);
@@ -181,6 +181,8 @@ switch(g.symmetry)
s /= g.range;
}
break;
case 2:
case -2:
case 1:
for(int i=g.number; i>0; --i)
{
@@ -221,6 +223,21 @@ for(int g=shape.size()-1; g>=0; --g)
return I;
}
//group-like multiplication table to combine symmetry adjustments due to several index groups
static const int signmultab[5][5] = {
{1,2,0,-2,-1},
{2,1,0,-1,-2},
{0,0,0,0,0},
{-2,-1,0,1,2},
{-1,-2,0,2,1}
};
static inline int signmult(int s1, int s2)
{
return signmultab[s1+2][s2+2];
}
template<typename T>
@@ -250,7 +267,7 @@ for(int g=0; g<shape.size(); ++g) //loop over index groups
int gsign;
LA_largeindex groupindex = subindex(&gsign,shape[g],I[g]);
//std::cout <<"INDEX TEST group "<<g<<" cumsizes "<< cumsizes[g]<<" groupindex "<<groupindex<<std::endl;
*sign *= gsign;
if(LA_traits<T>::is_complex()) *sign = signmult(*sign,gsign); else *sign *= gsign;
if(groupindex == -1) return -1;
r += groupindex * cumsizes[g];
}
@@ -276,7 +293,7 @@ for(int g=0; g<shape.size(); ++g) //loop over index groups
gstart=gend+1;
LA_largeindex groupindex = subindex(&gsign,shape[g],subI);
//std::cout <<"FLATINDEX TEST group "<<g<<" cumsizes "<< cumsizes[g]<<" groupindex "<<groupindex<<std::endl;
*sign *= gsign;
if(LA_traits<T>::is_complex()) *sign = signmult(*sign,gsign); else *sign *= gsign;
if(groupindex == -1) return -1;
r += groupindex * cumsizes[g];
}
@@ -408,6 +425,8 @@ switch(sh->symmetry)
istart= sh->offset;
iend= sh->offset+sh->range-1;
break;
case 2:
case -2:
case 1:
istart= sh->offset;
if(igroup==sh->number-1) iend= sh->offset+sh->range-1;
@@ -473,6 +492,8 @@ switch(sh->symmetry)
istart= sh->offset;
iend= sh->offset+sh->range-1;
break;
case 2:
case -2:
case 1:
istart= sh->offset;
if(igroup==sh->number-1) iend= sh->offset+sh->range-1;
@@ -1135,7 +1156,7 @@ if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in con
if(rhs1.shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction");
if(rhs1.shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction");
if(rhs1.shape[group].symmetry != rhs.shape[rhsgroup].symmetry) laerror("incompatible index symmetry in addgroupcontraction");
if(rhs1.shape[group].symmetry == 1) laerror("addgroupcontraction not implemented for symmetric index groups");
if(rhs1.shape[group].symmetry !=0 && rhs1.shape[group].symmetry != -1) laerror("addgroupcontraction only implemented for nonsymmetric and antisymmetric index groups");
#ifdef LA_TENSOR_INDEXPOSITION
if(rhs1.shape[group].upperindex ^ rhs.shape[rhsgroup].upperindex == false) laerror("can contact only upper with lower index");
#endif
@@ -1179,6 +1200,8 @@ if(kk!=rhsu.groupsizes[0]) laerror("internal error in addgroupcontraction");
T factor=alpha;
if(u.shape[0].symmetry== -1) factor=alpha*(T)factorial(u.shape[0].number);
if(u.shape[0].symmetry== 1) laerror("addgroupcontraction not implemented for symmetric index groups");
if(u.shape[0].symmetry== 2) laerror("addgroupcontraction not implemented for hermitean index groups");
if(u.shape[0].symmetry== -2) laerror("addgroupcontraction not implemented for antihermitean index groups");
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
data.copyonwrite();
@@ -1645,7 +1668,7 @@ if(is_named() && rhs.is_named() && names!=rhs.names) laerror("incompatible tenso
T factor=1;
for(int i=0; i<shape.size(); ++i)
{
if(shape[i].symmetry==1) laerror("unsupported index group symmetry in dot");
if(shape[i].symmetry==1||shape[i].symmetry==2||shape[i].symmetry== -2) laerror("unsupported index group symmetry in dot");
if(shape[i].symmetry== -1) factor *= (T)factorial(shape[i].number);
}
return factor * data.dot(rhs.data);
@@ -1897,8 +1920,8 @@ const INDEXGROUP *sh = &(* const_cast<const NRVec<INDEXGROUP> *>(&shape))[0];
for(int i=0; i<shape.size(); ++i)
{
if(sh[i].number==1 && sh[i].symmetry!=0) {shape.copyonwrite(); shape[i].symmetry=0;}
if(sh[i].symmetry>1 ) {shape.copyonwrite(); shape[i].symmetry=1;}
if(sh[i].symmetry<-1) {shape.copyonwrite(); shape[i].symmetry= -1;}
int maxlegal = LA_traits<T>::is_complex() ? 2 : 1;
if(sh[i].symmetry> maxlegal || sh[i].symmetry< -maxlegal) laerror("illegal index group symmetry specified");
}
}