diff --git a/mat.cc b/mat.cc index 2707835..04e7f52 100644 --- a/mat.cc +++ b/mat.cc @@ -779,6 +779,7 @@ void NRMat< complex >::diagmultr(const NRVec< complex > &rhs) +#ifdef oldversion // Mat * Smat, decomposed to nn x Vec * Smat template<> const NRMat @@ -810,6 +811,37 @@ NRMat< complex >::operator*(const NRSMat< complex > &rhs) const return result; } +#else + +// Mat * Smat +template<> +const NRMat +NRMat::operator*(const NRSMat &rhs) const +{ +#ifdef DEBUG + if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat"); +#endif + NRMat result(nn, rhs.ncols()); + cblas_dsymm(CblasRowMajor, CblasRight, CblasLower, nn, mm, 1., &rhs[0],mm,(*this)[0],mm,0.,result[0],mm); + return result; +} + +template<> +const NRMat< complex > +NRMat< complex >::operator*(const NRSMat< complex > &rhs) const +{ +#ifdef DEBUG + if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat"); +#endif + NRMat< complex > result(nn, rhs.ncols()); + cblas_zhemm(CblasRowMajor, CblasRight, CblasLower, nn, mm, (void *)&CONE, &rhs[0],mm,(*this)[0],mm,(void *)&CZERO,result[0],mm); + return result; +} + + + +#endif + // sum of rows template<> diff --git a/smat.cc b/smat.cc index cb30cd6..de3b2fd 100644 --- a/smat.cc +++ b/smat.cc @@ -148,6 +148,7 @@ void NRSMat::fscanf(FILE *f, const char *format) * BLAS specializations for double and complex */ +#ifdef oldversion // SMat * Mat template<> const NRMat NRSMat::operator*(const NRMat &rhs) const @@ -177,6 +178,36 @@ NRSMat< complex >::operator*(const NRMat< complex > &rhs) const return result; } +#else + +// SMat * Mat +template<> +const NRMat NRSMat::operator*(const NRMat &rhs) const +{ +#ifdef DEBUG + if (nn != rhs.nrows()) laerror("incompatible dimensions in SMat*Mat"); +#endif + NRMat result(nn, rhs.ncols()); + cblas_dsymm(CblasRowMajor, CblasLeft, CblasLower, nn, rhs.ncols(), 1., (*this),nn, rhs[0],rhs.ncols(), 0.,result[0],rhs.ncols()); + return result; +} + + +template<> +const NRMat< complex > +NRSMat< complex >::operator*(const NRMat< complex > &rhs) const +{ +#ifdef DEBUG + if (nn != rhs.nrows()) laerror("incompatible dimensions in SMat*Mat"); +#endif + NRMat< complex > result(nn, rhs.ncols()); + cblas_zhemm(CblasRowMajor, CblasLeft, CblasLower, nn, rhs.ncols(), &CONE, (*this),nn, rhs[0],rhs.ncols(), &CZERO,result[0],rhs.ncols()); + return result; +} + + +#endif + // SMat * SMat template<>