implemented complex conjugation method where it was not available yet

This commit is contained in:
2024-05-03 16:56:21 +02:00
parent 518c75fb20
commit 3ba6d03eee
8 changed files with 129 additions and 11 deletions

View File

@@ -647,26 +647,26 @@ return r;
template<typename T>
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0) //R(nn,mm) = A * B^T
static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A * B^T
{
for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j)
{
if(beta==0) r[i*mm+j]=0; else r[i*mm+j] *= beta;
for(int k=0; k<kk; ++k) r[i*mm+j] += alpha * a[i*kk+k] * b[j*kk+k];
for(int k=0; k<kk; ++k) r[i*mm+j] += alpha * a[i*kk+k] * conjugate? LA_traits<T>::conjugate(b[j*kk+k]) : b[j*kk+k];
}
}
template<>
void auxmatmult<double>(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta)
void auxmatmult<double>(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta, bool conjugate)
{
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, alpha, a, kk, b, kk, beta, r, mm);
}
template<>
void auxmatmult<std::complex<double> >(int nn, int mm, int kk, std::complex<double> *r, std::complex<double> *a, std::complex<double> *b, std::complex<double> alpha, std::complex<double> beta)
void auxmatmult<std::complex<double> >(int nn, int mm, int kk, std::complex<double> *r, std::complex<double> *a, std::complex<double> *b, std::complex<double> alpha, std::complex<double> beta, bool conjugate)
{
cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm);
cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans), nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm);
}
@@ -679,7 +679,7 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk,
//The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
//
template<typename T>
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize)
void Tensor<T>::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate)
{
if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction");
if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction");
@@ -711,7 +711,7 @@ kk=u.groupsizes[0];
if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction");
nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i];
mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i];
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta);
auxmatmult<T>(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate);
}