diff --git a/mat.cc b/mat.cc index a41cf70..b27582d 100644 --- a/mat.cc +++ b/mat.cc @@ -2078,11 +2078,15 @@ NRMat< std::complex >::operator*(const NRSMat< std::complex > &r /***************************************************************************//** - * conjugate this non-complex matrix \f$A\f$, i.e. do nothing :-) + * conjugate this matrix * @return reference to the (unmodified) matrix ******************************************************************************/ template NRMat& NRMat::conjugateme() { +#ifdef CUDALA + if(location != cpu) laerror("general conjugation only on CPU"); +#endif + for(int i=0; i::conjugate((*this)(i,j)); return *this; } diff --git a/mat.h b/mat.h index c76902b..57f445d 100644 --- a/mat.h +++ b/mat.h @@ -333,7 +333,7 @@ public: //! in case of square matrix, transpose the leading minor of order n NRMat& transposeme(const int n = 0); - //! conjugate a square matrix + //! conjugate a matrix NRMat& conjugateme(); //! transpose this matrix and return the result by value diff --git a/smat.cc b/smat.cc index 4f5e9ef..167f12a 100644 --- a/smat.cc +++ b/smat.cc @@ -866,6 +866,58 @@ NRSMat > NRSMat >::inverse() {return +/***************************************************************************//** + * conjugate this general matrix + * @return reference to the (unmodified) matrix + ******************************************************************************/ +template +NRSMat& NRSMat::conjugateme() { +#ifdef CUDALA + if(location != cpu) laerror("general conjugation only on CPU"); +#endif + for(int i=0; i::conjugate(v[i]); + return *this; +} + + +/***************************************************************************//** + * conjugate this complex matrix + * @return reference to the modified matrix + ******************************************************************************/ +template<> +NRSMat >& NRSMat >::conjugateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_dscal((size_t)NN2, -1.0, ((double *)v) + 1, 2); +#ifdef CUDALA + }else{ + cublasDscal((size_t)NN2, -1.0, ((double *)v) + 1, 2); + } +#endif + return *this; +} + +template<> +NRSMat >& NRSMat >::conjugateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_sscal((size_t)NN2, -1.0, ((float *)v) + 1, 2); +#ifdef CUDALA + }else{ + cublasSscal((size_t)NN2, -1.0, ((float *)v) + 1, 2); + } +#endif + return *this; +} + + + + + /***************************************************************************//** * forced instantization in the corresponding object file diff --git a/smat.h b/smat.h index 424c115..c891b33 100644 --- a/smat.h +++ b/smat.h @@ -99,6 +99,10 @@ public: //! inverse matrix NRSMat inverse(); + //! conjugate a matrix + NRSMat& conjugateme(); + const NRSMat conjugate() const {NRSMat r(*this); r.conjugateme(); return r;}; + //! permute matrix elements const NRSMat permuted(const NRPerm &p, const bool inverse=false) const; diff --git a/tensor.cc b/tensor.cc index dfed10d..d1e82f6 100644 --- a/tensor.cc +++ b/tensor.cc @@ -647,26 +647,26 @@ return r; template -static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0) //R(nn,mm) = A * B^T +static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b, T alpha=1, T beta=0, bool conjugate=false) //R(nn,mm) = A * B^T { for(int i=0; i::conjugate(b[j*kk+k]) : b[j*kk+k]; } } template<> -void auxmatmult(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta) +void auxmatmult(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta, bool conjugate) { cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, alpha, a, kk, b, kk, beta, r, mm); } template<> -void auxmatmult >(int nn, int mm, int kk, std::complex *r, std::complex *a, std::complex *b, std::complex alpha, std::complex beta) +void auxmatmult >(int nn, int mm, int kk, std::complex *r, std::complex *a, std::complex *b, std::complex alpha, std::complex beta, bool conjugate) { -cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm); +cblas_zgemm(CblasRowMajor, CblasNoTrans, (conjugate?CblasConjTrans:CblasTrans), nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm); } @@ -679,7 +679,7 @@ cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk, //The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices // template -void Tensor::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize) +void Tensor::addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha, T beta, bool doresize, bool conjugate) { if(group<0||group>=rhs1.shape.size()) laerror("wrong group number in contraction"); if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction"); @@ -711,7 +711,7 @@ kk=u.groupsizes[0]; if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction"); nn=1; for(int i=1; i(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta); +auxmatmult(nn,mm,kk,&data[0],&u.data[0], &rhsu.data[0],alpha,beta,conjugate); } diff --git a/tensor.h b/tensor.h index 9926dae..b34b171 100644 --- a/tensor.h +++ b/tensor.h @@ -147,6 +147,10 @@ public: inline Tensor& operator/=(const T &a) {data/=a; return *this;}; inline Tensor operator/(const T &a) const {Tensor r(*this); r /=a; return r;}; + Tensor& conjugateme() {data.conjugateme(); return *this;}; + inline Tensor conjugate() const {Tensor r(*this); r.conjugateme(); return r;}; + + inline Tensor& operator+=(const Tensor &rhs) { @@ -180,8 +184,8 @@ public: Tensor permute_index_groups(const NRPerm &p) const; //rearrange the tensor storage permuting index groups as a whole Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one - void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false); - inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const {Tensor r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true); return r; } + void addcontraction(const Tensor &rhs1, int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, T beta=1, bool doresize=false, bool conjugate=false); + inline Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1, bool conjugate=false) const {Tensor r; r.addcontraction(*this,group,index,rhs,rhsgroup,rhsindex,alpha,0,true, conjugate); return r; } void apply_permutation_algebra(const Tensor &rhs, const PermutationAlgebra &pa, bool inverse=false, T alpha=1, T beta=0); //general (not optimally efficient) symmetrizers, antisymmetrizers etc. acting on the flattened index list: // this *=beta; for I over this: this(I) += alpha * sum_P c_P rhs(P(I)) diff --git a/vec.cc b/vec.cc index c0951cf..af8283e 100644 --- a/vec.cc +++ b/vec.cc @@ -903,6 +903,57 @@ void NRVec::storesubvector(const NRVec &selection, const NRVec &rhs) } } +/***************************************************************************//** + * conjugate this general vector + * @return reference to the (unmodified) matrix + ******************************************************************************/ +template +NRVec& NRVec::conjugateme() { +#ifdef CUDALA + if(location != cpu) laerror("general conjugation only on CPU"); +#endif + for(int i=0; i::conjugate(v[i]); + return *this; +} + + +/***************************************************************************//** + * conjugate this complex vector + * @return reference to the modified matrix + ******************************************************************************/ +template<> +NRVec >& NRVec >::conjugateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_dscal((size_t)nn, -1.0, ((double *)v) + 1, 2); +#ifdef CUDALA + }else{ + cublasDscal((size_t)nn, -1.0, ((double *)v) + 1, 2); + } +#endif + return *this; +} + +template<> +NRVec >& NRVec >::conjugateme() { + copyonwrite(); +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_sscal((size_t)nn, -1.0, ((float *)v) + 1, 2); +#ifdef CUDALA + }else{ + cublasSscal((size_t)nn, -1.0, ((float *)v) + 1, 2); + } +#endif + return *this; +} + + + + /***************************************************************************//** diff --git a/vec.h b/vec.h index 34b5b0f..c043007 100644 --- a/vec.h +++ b/vec.h @@ -298,6 +298,9 @@ public: v[0] = a; } + //! complex conjugate + NRVec& conjugateme(); + inline NRVec conjugate() const {NRVec r(*this); r.conjugateme(); return r;}; //! determine the actual value of the reference counter inline int getcount() const {return count?*count:0;}