#include "mat.h" // TODO : // ////////////////////////////////////////////////////////////////////////////// //// forced instantization in the corresponding object file template NRMat; template NRMat< complex >; template NRMat; template NRMat; /* * Templates first, specializations for BLAS next */ // Assign diagonal template NRMat & NRMat::operator=(const T &a) { copyonwrite(); #ifdef DEBUG if (nn != mm) laerror("RMat.operator=scalar on non-square matrix"); #endif #ifdef MATPTR for (int i=0; i< nn; i++) v[i][i] = a; #else for (int i=0; i< nn*nn; i+=nn+1) v[i] = a; #endif return *this; } // M += a template NRMat & NRMat::operator+=(const T &a) { copyonwrite(); #ifdef DEBUG if (nn != mm) laerror("Mat.operator+=scalar on non-square matrix"); #endif #ifdef MATPTR for (int i=0; i< nn; i++) v[i][i] += a; #else for (int i=0; i< nn*nn; i+=nn+1) v[i] += a; #endif return *this; } // M -= a template NRMat & NRMat::operator-=(const T &a) { copyonwrite(); #ifdef DEBUG if (nn != mm) laerror("Mat.operator-=scalar on non-square matrix"); #endif #ifdef MATPTR for (int i=0; i< nn; i++) v[i][i] -= a; #else for (int i=0; i< nn*nn; i+=nn+1) v[i] -= a; #endif return *this; } // unary minus template const NRMat NRMat::operator-() const { NRMat result(nn, mm); #ifdef MATPTR for (int i=0; i const NRMat NRMat::operator&(const NRMat & b) const { NRMat result((T)0, nn+b.nn, mm+b.mm); for (int i=0; i const NRMat NRMat::operator|(const NRMat &b) const { NRMat result(nn*b.nn, mm*b.mm); for (int i=0; i const NRVec NRMat::csum() const { NRVec result(nn); T sum; for (int i=0; i const NRVec NRMat::rsum() const { NRVec result(nn); T sum; for (int i=0; i NRMat & NRMat::transposeme() { #ifdef DEBUG if (nn != mm) laerror("transpose of non-square Mat"); #endif copyonwrite(); for(int i=1; i void NRMat::fprintf(FILE *file, const char *format, const int modulo) const { lawritemat(file, (const T*)(*this), nn, mm, format, 2, modulo, 0); } // Input of Mat template void NRMat::fscanf(FILE *f, const char *format) { int n, m; if (std::fscanf(f, "%d %d", &n, &m) != 2) laerror("cannot read matrix dimensions in Mat::fscanf()"); resize(n,m); T *p = *this; for(int i=0; i */ // Mat *= a NRMat & NRMat::operator*=(const double &a) { copyonwrite(); cblas_dscal(nn*mm, a, *this, 1); return *this; } NRMat< complex > & NRMat< complex >::operator*=(const complex &a) { copyonwrite(); cblas_zscal(nn*mm, &a, (void *)(*this)[0], 1); return *this; } //and for general type template NRMat & NRMat::operator*=(const T &a) { copyonwrite(); #ifdef MATPTR for (int i=0; i< nn*nn; i++) v[0][i] *= a; #else for (int i=0; i< nn*nn; i++) v[i] *= a; #endif return *this; } // Mat += Mat NRMat & NRMat::operator+=(const NRMat &rhs) { #ifdef DEBUG if (nn != rhs.nn || mm!= rhs.mm) laerror("Mat += Mat of incompatible matrices"); #endif copyonwrite(); cblas_daxpy(nn*mm, 1.0, rhs, 1, *this, 1); return *this; } NRMat< complex > & NRMat< complex >::operator+=(const NRMat< complex > &rhs) { #ifdef DEBUG if (nn != rhs.nn || mm!= rhs.mm) laerror("Mat += Mat of incompatible matrices"); #endif copyonwrite(); cblas_zaxpy(nn*mm, &CONE, (void *)rhs[0], 1, (void *)(*this)[0], 1); return *this; } //and for general type template NRMat & NRMat::operator+=(const NRMat &rhs) { #ifdef DEBUG if (nn != rhs.nn || mm!= rhs.mm) laerror("Mat -= Mat of incompatible matrices"); #endif copyonwrite(); #ifdef MATPTR for (int i=0; i< nn*nn; i++) v[0][i] += rhs.v[0][i] ; #else for (int i=0; i< nn*nn; i++) v[i] += rhs.v[i] ; #endif return *this; } // Mat -= Mat NRMat & NRMat::operator-=(const NRMat &rhs) { #ifdef DEBUG if (nn != rhs.nn || mm!= rhs.mm) laerror("Mat -= Mat of incompatible matrices"); #endif copyonwrite(); cblas_daxpy(nn*mm, -1.0, rhs, 1, *this, 1); return *this; } NRMat< complex > & NRMat< complex >::operator-=(const NRMat< complex > &rhs) { #ifdef DEBUG if (nn != rhs.nn || mm!= rhs.mm) laerror("Mat -= Mat of incompatible matrices"); #endif copyonwrite(); cblas_zaxpy(nn*mm, &CMONE, (void *)rhs[0], 1, (void *)(*this)[0], 1); return *this; } //and for general type template NRMat & NRMat::operator-=(const NRMat &rhs) { #ifdef DEBUG if (nn != rhs.nn || mm!= rhs.mm) laerror("Mat -= Mat of incompatible matrices"); #endif copyonwrite(); #ifdef MATPTR for (int i=0; i< nn*nn; i++) v[0][i] -= rhs.v[0][i] ; #else for (int i=0; i< nn*nn; i++) v[i] -= rhs.v[i] ; #endif return *this; } // Mat += SMat NRMat & NRMat::operator+=(const NRSMat &rhs) { #ifdef DEBUG if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat"); #endif const double *p = rhs; copyonwrite(); for (int i=0; i > & NRMat< complex >::operator+=(const NRSMat< complex > &rhs) { #ifdef DEBUG if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat"); #endif const complex *p = rhs; copyonwrite(); for (int i=0; i NRMat & NRMat::operator+=(const NRSMat &rhs) { #ifdef DEBUG if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat"); #endif const T *p = rhs; copyonwrite(); for (int i=0; i & NRMat::operator-=(const NRSMat &rhs) { #ifdef DEBUG if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat"); #endif const double *p = rhs; copyonwrite(); for (int i=0; i > & NRMat< complex >::operator-=(const NRSMat< complex > &rhs) { #ifdef DEBUG if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat"); #endif const complex *p = rhs; copyonwrite(); for (int i=0; i NRMat & NRMat::operator-=(const NRSMat &rhs) { #ifdef DEBUG if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat"); #endif const T *p = rhs; copyonwrite(); for (int i=0; i::dot(const NRMat &rhs) const { #ifdef DEBUG if(nn!=rhs.nn || mm!= rhs.mm) laerror("Mat.Mat incompatible matrices"); #endif return cblas_ddot(nn*mm, (*this)[0], 1, rhs[0], 1); } const complex NRMat< complex >::dot(const NRMat< complex > &rhs) const { #ifdef DEBUG if(nn!=rhs.nn || mm!= rhs.mm) laerror("Mat.Mat incompatible matrices"); #endif complex dot; cblas_zdotc_sub(nn*mm, (void *)(*this)[0], 1, (void *)rhs[0], 1, (void *)(&dot)); return dot; } // Mat * Mat const NRMat NRMat::operator*(const NRMat &rhs) const { #ifdef DEBUG if (mm != rhs.nn) laerror("product of incompatible matrices"); #endif NRMat result(nn, rhs.mm); cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, nn, rhs.mm, mm, 1.0, *this, mm, rhs, rhs.mm, 0.0, result, rhs.mm); return result; } const NRMat< complex > NRMat< complex >::operator*(const NRMat< complex > &rhs) const { #ifdef DEBUG if (mm != rhs.nn) laerror("product of incompatible matrices"); #endif NRMat< complex > result(nn, rhs.mm); cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, nn, rhs.mm, mm, (const void *)(&CONE),(const void *)(*this)[0], mm, (const void *)rhs[0], rhs.mm, (const void *)(&CZERO), (void *)result[0], rhs.mm); return result; } // Multiply by diagonal from L void NRMat::diagmultl(const NRVec &rhs) { #ifdef DEBUG if (nn != rhs.size()) laerror("incompatible matrix dimension in diagmultl"); #endif copyonwrite(); for(int i=0; i >::diagmultl(const NRVec< complex > &rhs) { #ifdef DEBUG if (nn != rhs.size()) laerror("incompatible matrix dimension in diagmultl"); #endif copyonwrite(); for (int i=0; i::diagmultr(const NRVec &rhs) { #ifdef DEBUG if (mm != rhs.size()) laerror("incompatible matrix dimension in diagmultr"); #endif copyonwrite(); for (int i=0; i >::diagmultr(const NRVec< complex > &rhs) { #ifdef DEBUG if (mm != rhs.size()) laerror("incompatible matrix dimension in diagmultl"); #endif copyonwrite(); for (int i=0; i NRMat::operator*(const NRSMat &rhs) const { #ifdef DEBUG if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat"); #endif NRMat result(nn, rhs.ncols()); for (int i=0; i > NRMat< complex >::operator*(const NRSMat< complex > &rhs) const { #ifdef DEBUG if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat"); #endif NRMat< complex > result(nn, rhs.ncols()); for (int i=0; i NRMat::operator*(const NRVec &vec) const { #ifdef DEBUG if(mm != vec.size()) laerror("incompatible sizes in Mat*Vec"); #endif NRVec result(nn); cblas_dgemv(CblasRowMajor, CblasNoTrans, nn, mm, 1.0, (*this)[0], mm, &vec[0], 1, 0.0, &result[0], 1); return result; } const NRVec< complex > NRMat< complex >::operator*(const NRVec< complex > &vec) const { #ifdef DEBUG if(mm != vec.size()) laerror("incompatible sizes in Mat*Vec"); #endif NRVec< complex > result(nn); cblas_zgemv(CblasRowMajor, CblasNoTrans, nn, mm, (void *)&CONE, (void *)(*this)[0], mm, (void *)&vec[0], 1, (void *)&CZERO, (void *)&result[0], 1); return result; } // sum of rows const NRVec NRMat::rsum() const { NRVec result(mm); for (int i=0; i NRMat::csum() const { NRVec result(nn); for (int i=0; i &NRMat::conjugateme() {return *this;} NRMat< complex > & NRMat< complex >::conjugateme() { copyonwrite(); cblas_dscal(mm*nn, -1.0, (double *)((*this)[0])+1, 2); return *this; } // transpose and optionally conjugate const NRMat NRMat::transpose(bool conj) const { NRMat result(mm,nn); for(int i=0; i > NRMat< complex >::transpose(bool conj) const { NRMat< complex > result(mm,nn); for (int i=0; i::gemm(const double &beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const double &alpha) { int l(transa=='n'?a.nn:a.mm); int k(transa=='n'?a.mm:a.nn); int kk(transb=='n'?b.nn:b.mm); int ll(transb=='n'?b.mm:b.nn); #ifdef DEBUG if (l!=nn || ll!=mm || k!=kk) laerror("incompatible matrices in Mat:gemm()"); #endif if (alpha==0.0 && beta==1.0) return; copyonwrite(); cblas_dgemm(CblasRowMajor, (transa=='n' ? CblasNoTrans : CblasTrans), (transb=='n' ? CblasNoTrans : CblasTrans), nn, mm, k, alpha, a, a.mm, b , b.mm, beta, *this , mm); } void NRMat< complex >::gemm(const complex & beta, const NRMat< complex > & a, const char transa, const NRMat< complex > & b, const char transb, const complex & alpha) { int l(transa=='n'?a.nn:a.mm); int k(transa=='n'?a.mm:a.nn); int kk(transb=='n'?b.nn:b.mm); int ll(transb=='n'?b.mm:b.nn); #ifdef DEBUG if (l!=nn || ll!=mm || k!=kk) laerror("incompatible matrices in Mat:gemm()"); #endif if (alpha==CZERO && beta==CONE) return; copyonwrite(); cblas_zgemm(CblasRowMajor, (transa=='n' ? CblasNoTrans : (transa=='c'?CblasConjTrans:CblasTrans)), (transb=='n' ? CblasNoTrans : (transa=='c'?CblasConjTrans:CblasTrans)), nn, mm, k, &alpha, a , a.mm, b , b.mm, &beta, *this , mm); } // norm of Mat const double NRMat::norm(const double scalar) const { if (!scalar) return cblas_dnrm2(nn*mm, (*this)[0], 1); double sum = 0; for (int i=0; i >::norm(const complex scalar) const { if (scalar == CZERO) return cblas_dznrm2(nn*mm, (*this)[0], 1); double sum = 0; for (int i=0; i tmp; #ifdef MATPTR tmp = v[i][j]; #else tmp = v[i*mm+j]; #endif if (i==j) tmp -= scalar; sum += tmp.real()*tmp.real()+tmp.imag()*tmp.imag(); } return sqrt(sum); } // axpy: this = a * Mat void NRMat::axpy(const double alpha, const NRMat &mat) { #ifdef DEBUG if (nn!=mat.nn || mm!=mat.mm) laerror("daxpy of incompatible matrices"); #endif copyonwrite(); cblas_daxpy(nn*mm, alpha, mat, 1, *this, 1); } void NRMat< complex >::axpy(const complex alpha, const NRMat< complex > & mat) { #ifdef DEBUG if (nn!=mat.nn || mm!=mat.mm) laerror("zaxpy of incompatible matrices"); #endif copyonwrite(); cblas_zaxpy(nn*mm, (void *)&alpha, mat, 1, (void *)(*this)[0], 1); } // trace of Mat const double NRMat::trace() const { #ifdef DEBUG if (nn != mm) laerror("no-square matrix in Mat::trace()"); #endif return cblas_dasum(nn, (*this)[0], nn+1); } const complex NRMat< complex >::trace() const { #ifdef DEBUG if (nn != mm) laerror("no-square matrix in Mat::trace()"); #endif register complex sum = CZERO; for (int i=0; i::diagonalof(NRVec &r, const bool divide) const { #ifdef DEBUG if (r.size() != nn) laerror("diagonalof() incompatible vector"); #endif double a; r.copyonwrite(); if(nn==mm) { #ifdef MATPTR if(divide) for (int i=0; i< nn; i++) if((a=v[i][i])) r[i]/=a; else for (int i=0; i< nn; i++) r[i] = v[i][i]; #else if(divide) {int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) if((a=v[i])) r[j] /=a;} else {int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) r[j] = v[i];} #endif } else //non-square { for (int i=0; i< mm; i++) { #ifdef MATPTR a= cblas_ddot(nn,v[0]+i,mm,v[0]+i,mm); #else a=cblas_ddot(nn,v+i,mm,v+i,mm); #endif if(divide) {if(a) r[i]/=a;} else r[i] = a; } } } ////////////////////////////////////////////////////////////////////////////// //// forced instantization in the corespoding object file #define INSTANTIZE(T) \ template ostream & operator<<(ostream &s, const NRMat< T > &x); \ template istream & operator>>(istream &s, NRMat< T > &x); \ INSTANTIZE(double) INSTANTIZE(complex) INSTANTIZE(int) INSTANTIZE(char) export template ostream& operator<<(ostream &s, const NRMat &x) { int i,j,n,m; n=x.nrows(); m=x.ncols(); s << n << ' ' << m << '\n'; for(i=0;i istream& operator>>(istream &s, NRMat &x) { int i,j,n,m; s >> n >> m; x.resize(n,m); for(i=0;i>x[i][j] ; return s; }