#include "mat.h"
// TODO :
//

/*
 * NOTE(review): this copy of the file appears to have lost text during
 * extraction.  Template argument lists are missing (for example
 * "template NRMat::~NRMat()"), and several loops stop right after
 * "for (int i=0; i" with everything up to the next '>' gone.  Those spans
 * are reproduced verbatim below and tagged "[extraction damage]"; restore
 * them from a pristine copy of the file, never from this one.
 * Only definitions whose text survived intact were modified.
 */

//////////////////////////////////////////////////////////////////////////////
//// forced instantization in the corresponding object file
template NRMat;
template NRMat< complex >;


/*
 * Templates first, specializations for BLAS next
 */

// dtor
// Drops this object's reference.  When the shared counter reaches zero the
// element storage (plus the row-pointer table under MATPTR) and the counter
// itself are released.  An object without a counter owns nothing.
template NRMat::~NRMat()
{
    if (!count) return;
    if (--(*count) <= 0) {
        if (v) {
#ifdef MATPTR
            delete[] (v[0]);
#endif
            delete[] v;
        }
        delete count;
    }
}

// assign NRMat = NRMat
// Shallow assignment: *this ends up sharing rhs's storage and counter.
//
// Fixes relative to the previous version:
//  * the members are now copied even when *this had no storage yet (the copy
//    used to sit inside "if (count)", so assigning to an unallocated matrix
//    silently did nothing);
//  * the shared counter is incremented after taking the new reference (it
//    used to be decremented, which frees the data while it is still in use).
template NRMat & NRMat::operator=(const NRMat &rhs)
{
    if (this == &rhs) return *this;

    // Give up our reference to the old storage, freeing it if we were last.
    if (count) {
        if (--(*count) == 0) {
            if (v) {    // same guard as the destructor
#ifdef MATPTR
                delete[] (v[0]);
#endif
                delete[] v;
            }
            delete count;
        }
    }

    // Share rhs's storage.
    v = rhs.v;
    nn = rhs.nn;
    mm = rhs.mm;
    count = rhs.count;
    if (count) (*count)++;   // rhs may itself be unallocated
    return *this;
}

// Assign diagonal
// Writes a to every diagonal element; off-diagonal elements are not touched.
// DEBUG builds report a non-square matrix through laerror().
template NRMat & NRMat::operator=(const T &a)
{
    copyonwrite();
#ifdef DEBUG
    if (nn != mm) laerror("RMat.operator=scalar on non-square matrix");
#endif
#ifdef MATPTR
    for (int i=0; i< nn; i++) v[i][i] = a;
#else
    // flat storage: consecutive diagonal elements are nn+1 apart
    for (int i=0; i< nn*nn; i+=nn+1) v[i] = a;
#endif
    return *this;
}

// Explicit deep copy of NRmat
// After the call *this holds its own copy of rhs's elements with a counter
// of 1.  Existing private storage is reused when the dimensions match.
template NRMat & NRMat::operator|=(const NRMat &rhs)
{
    if (this == &rhs) return *this;
#ifdef DEBUG
    if (!rhs.v) laerror("unallocated rhs in Mat operator |=");
#endif
    // Storage shared with somebody else: detach and start from scratch.
    if (count)
        if (*count > 1) {
            --(*count);
            nn = 0;
            mm = 0;
            count = 0;
            v = 0;
        }
    // Private storage of the wrong size: drop it.
    if (nn != rhs.nn || mm != rhs.mm) {
        if (v) {
#ifdef MATPTR
            delete[] (v[0]);
#endif
            delete[] (v);
            v = 0;
        }
        nn = rhs.nn;
        mm = rhs.mm;
    }
    if (!v) {
#ifdef MATPTR
        v = new T*[nn];
        v[0] = new T[mm*nn];
#else
        v = new T[mm*nn];
#endif
    }
#ifdef MATPTR
    for (int i=1; i< nn; i++) v[i] = v[i-1] + mm;
    memcpy(v[0], rhs.v[0], nn*mm*sizeof(T));
#else
    memcpy(v, rhs.v, nn*mm*sizeof(T));
#endif
    if (!count) count = new int;
    *count = 1;
    return *this;
}

// M += a
// Adds a to every diagonal element only.
template NRMat & NRMat::operator+=(const T &a)
{
    copyonwrite();
#ifdef DEBUG
    if (nn != mm) laerror("Mat.operator+=scalar on non-square matrix");
#endif
#ifdef MATPTR
    for (int i=0; i< nn; i++) v[i][i] += a;
#else
    for (int i=0; i< nn*nn; i+=nn+1) v[i] += a;
#endif
    return *this;
}

// M -= a
// Subtracts a from every diagonal element only.
template NRMat & NRMat::operator-=(const T &a)
{
    copyonwrite();
#ifdef DEBUG
    if (nn != mm) laerror("Mat.operator-=scalar on non-square matrix");
#endif
#ifdef MATPTR
    for (int i=0; i< nn; i++) v[i][i] -= a;
#else
    for (int i=0; i< nn*nn; i+=nn+1) v[i] -= a;
#endif
    return *this;
}

// unary minus
// [extraction damage] the loop body, the end of this function and the
// "template <...>" head of the next definition are missing.
template const NRMat NRMat::operator-() const
{
    NRMat result(nn, mm);
#ifdef MATPTR
    for (int i=0; i

// [extraction damage] body truncated after the first loop header.
const NRMat NRMat::operator&(const NRMat & b) const
{
    NRMat result((T)0, nn+b.nn, mm+b.mm);
    for (int i=0; i

// [extraction damage] body truncated after the first loop header.
const NRMat NRMat::operator|(const NRMat &b) const
{
    NRMat result(nn*b.nn, mm*b.mm);
    for (int i=0; i

// [extraction damage] body truncated after the first loop header.
const NRVec NRMat::csum() const
{
    NRVec result(nn);
    T sum;
    for (int i=0; i

// [extraction damage] body truncated after the first loop header.
// NOTE(review): the result is sized with nn here, whereas the BLAS
// specialization of rsum() further down uses mm - confirm against the
// pristine source which one is intended.
const NRVec NRMat::rsum() const
{
    NRVec result(nn);
    T sum;
    for (int i=0; i

// Gives *this private storage before a modification: when the counter shows
// more than one owner, the shared counter is decremented and a new counter
// plus a copy of the mm*nn elements is allocated for this object.
// Without DEBUG an object that has no counter is dereferenced unchecked.
void NRMat::copyonwrite()
{
#ifdef DEBUG
    if (!count) laerror("Mat::copyonwrite of undefined matrix");
#endif
    if (*count > 1) {
        (*count)--;
        count = new int;
        *count = 1;
#ifdef MATPTR
        T **newv = new T*[nn];
        newv[0] = new T[mm*nn];
        memcpy(newv[0], v[0], mm*nn*sizeof(T));
        v = newv;
        for (int i=1; i< nn; i++) v[i] = v[i-1] + mm;
#else
        T *newv = new T[mm*nn];
        memcpy(newv, v, mm*nn*sizeof(T));
        v = newv;
#endif
    }
}

// Changes the dimensions to n x m.  Old element values are never copied
// into newly allocated storage; if the object is the sole owner and the
// dimensions already match, nothing happens.
template void NRMat::resize(const int n, const int m)
{
#ifdef DEBUG
    if (n<=0 || m<=0) laerror("illegal dimensions in Mat::resize()");
#endif
    // Shared storage: detach from it and fall through to a fresh allocation.
    if (count)
        if (*count > 1) {
            (*count)--;
            count = 0;
            v = 0;
            nn = 0;
            mm = 0;
        }
    if (!count) {
        count = new int;
        *count = 1;
        nn = n;
        mm = m;
#ifdef MATPTR
        v = new T*[nn];
        v[0] = new T[m*n];
        for (int i=1; i< n; i++) v[i] = v[i-1] + m;
#else
        v = new T[m*n];
#endif
        return;
    }
    // At this point *count = 1, check if resize is necessary
    if (n!=nn || m!=mm) {
        nn = n;
        mm = m;
#ifdef MATPTR
        delete[] (v[0]);
#endif
        delete[] v;
#ifdef MATPTR
        v = new T*[nn];
        v[0] = new T[m*n];
        for (int i=1; i< n; i++) v[i] = v[i-1] + m;
#else
        v = new T[m*n];
#endif
    }
}

// transpose Mat
// [extraction damage] the loop body, the end of this function and the head
// of fprintf() are missing.
template NRMat & NRMat::transposeme()
{
#ifdef DEBUG
    if (nn != mm) laerror("transpose of non-square Mat");
#endif
    copyonwrite();
    for(int i=1; i

// Writes the matrix to file by forwarding the data pointer, nn, mm, format
// and modulo to lawritemat().
void NRMat::fprintf(FILE *file, const char *format, const int modulo) const
{
    lawritemat(file, (const T*)(*this), nn, mm, format, 2, modulo, 0);
}

// Input of Mat
// Reads "rows cols" from f, resizes, then reads the elements.
// [extraction damage] the element loop and a following block comment (only
// its closing "*/" survived) are missing.
template void NRMat::fscanf(FILE *f, const char *format)
{
    int n, m;
    if (std::fscanf(f, "%d %d", &n, &m) != 2)
        laerror("cannot read matrix dimensions in Mat::fscanf()");
    resize(n,m);
    T *p = *this;
    for(int i=0; i */

// Mat *= a
// Scales all nn*mm elements by a via cblas_dscal.
NRMat & NRMat::operator*=(const double &a)
{
    copyonwrite();
    cblas_dscal(nn*mm, a, *this, 1);
    return *this;
}

// Complex variant: scales all nn*mm elements by a via cblas_zscal.
NRMat< complex > & NRMat< complex >::operator*=(const complex &a)
{
    copyonwrite();
    cblas_zscal(nn*mm, &a, (void *)(*this)[0], 1);
    return *this;
}

// Mat += Mat
// Element-wise addition via cblas_daxpy with factor 1.0.
NRMat & NRMat::operator+=(const NRMat &rhs)
{
#ifdef DEBUG
    if (nn != rhs.nn || mm!= rhs.mm)
        laerror("Mat += Mat of incompatible matrices");
#endif
    copyonwrite();
    cblas_daxpy(nn*mm, 1.0, rhs, 1, *this, 1);
    return *this;
}

// Complex variant: cblas_zaxpy with factor CONE.
NRMat< complex > & NRMat< complex >::operator+=(const NRMat< complex > &rhs)
{
#ifdef DEBUG
    if (nn != rhs.nn || mm!= rhs.mm)
        laerror("Mat += Mat of incompatible matrices");
#endif
    copyonwrite();
    cblas_zaxpy(nn*mm, &CONE, (void *)rhs[0], 1, (void *)(*this)[0], 1);
    return *this;
}

// Mat -= Mat
// Element-wise subtraction via cblas_daxpy with factor -1.0.
NRMat & NRMat::operator-=(const NRMat &rhs)
{
#ifdef DEBUG
    if (nn != rhs.nn || mm!= rhs.mm)
        laerror("Mat -= Mat of incompatible matrices");
#endif
    copyonwrite();
    cblas_daxpy(nn*mm, -1.0, rhs, 1, *this, 1);
    return *this;
}

// Complex variant: cblas_zaxpy with factor CMONE.
NRMat< complex > & NRMat< complex >::operator-=(const NRMat< complex > &rhs)
{
#ifdef DEBUG
    if (nn != rhs.nn || mm!= rhs.mm)
        laerror("Mat -= Mat of incompatible matrices");
#endif
    copyonwrite();
    cblas_zaxpy(nn*mm, &CMONE, (void *)rhs[0], 1, (void *)(*this)[0], 1);
    return *this;
}

// Mat += SMat
// [extraction damage] the loop body, the end of this function and the start
// of the complex overload's return type are missing.
NRMat & NRMat::operator+=(const NRSMat &rhs)
{
#ifdef DEBUG
    if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat");
#endif
    const double *p = rhs;
    copyonwrite();
    for (int i=0; i

// [extraction damage] head and body truncated.
> & NRMat< complex >::operator+=(const NRSMat< complex > &rhs)
{
#ifdef DEBUG
    if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat");
#endif
    const complex *p = rhs;
    copyonwrite();
    for (int i=0; i

// [extraction damage] head and body truncated (Mat -= SMat, double).
& NRMat::operator-=(const NRSMat &rhs)
{
#ifdef DEBUG
    if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat");
#endif
    const double *p = rhs;
    copyonwrite();
    for (int i=0; i

// [extraction damage] head and body truncated (Mat -= SMat, complex).
> & NRMat< complex >::operator-=(const NRSMat< complex > &rhs)
{
#ifdef DEBUG
    if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat");
#endif
    const complex *p = rhs;
    copyonwrite();
    for (int i=0; i

// [extraction damage] return type and class qualifier lost.
// Returns cblas_ddot over all nn*mm elements of *this and rhs.
::dot(const NRMat &rhs) const
{
#ifdef DEBUG
    if(nn!=rhs.nn || mm!= rhs.mm) laerror("Mat.Mat incompatible matrices");
#endif
    return cblas_ddot(nn*mm, (*this)[0], 1, rhs[0], 1);
}

// Complex variant: uses cblas_zdotc_sub, i.e. the conjugating dot product,
// with *this as the first (conjugated) argument.
const complex NRMat< complex >::dot(const NRMat< complex > &rhs) const
{
#ifdef DEBUG
    if(nn!=rhs.nn || mm!= rhs.mm) laerror("Mat.Mat incompatible matrices");
#endif
    complex dot;
    cblas_zdotc_sub(nn*mm, (void *)(*this)[0], 1, (void *)rhs[0], 1,
            (void *)(&dot));
    return dot;
}

// Mat * Mat
// Row-major product via cblas_dgemm; the result is nn x rhs.mm.
const NRMat NRMat::operator*(const NRMat &rhs) const
{
#ifdef DEBUG
    if (mm != rhs.nn) laerror("product of incompatible matrices");
#endif
    NRMat result(nn, rhs.mm);
    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, nn, rhs.mm, mm, 1.0,
            *this, mm, rhs, rhs.mm, 0.0, result, rhs.mm);
    return result;
}

// Complex variant via cblas_zgemm with alpha = CONE, beta = CZERO.
const NRMat< complex > NRMat< complex >::operator*(const NRMat< complex > &rhs) const
{
#ifdef DEBUG
    if (mm != rhs.nn) laerror("product of incompatible matrices");
#endif
    NRMat< complex > result(nn, rhs.mm);
    cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, nn, rhs.mm, mm,
            (const void *)(&CONE),(const void *)(*this)[0], mm, (const void *)rhs[0],
            rhs.mm, (const void *)(&CZERO), (void *)result[0], rhs.mm);
    return result;
}

// Multiply by diagonal from L
// [extraction damage] the loop body, the end of this function and the head
// of the complex overload are missing.
void NRMat::diagmultl(const NRVec &rhs)
{
#ifdef DEBUG
    if (nn != rhs.size()) laerror("incompatible matrix dimension in diagmultl");
#endif
    copyonwrite();
    for(int i=0; i

// [extraction damage] head and body truncated (diagmultl, complex).
>::diagmultl(const NRVec< complex > &rhs)
{
#ifdef DEBUG
    if (nn != rhs.size()) laerror("incompatible matrix dimension in diagmultl");
#endif
    copyonwrite();
    for (int i=0; i

// [extraction damage] head and body truncated (diagmultr, double).
::diagmultr(const NRVec &rhs)
{
#ifdef DEBUG
    if (mm != rhs.size()) laerror("incompatible matrix dimension in diagmultr");
#endif
    copyonwrite();
    for (int i=0; i

// [extraction damage] head and body truncated (diagmultr, complex).
// The error text used to say "diagmultl"; it now names this function.
>::diagmultr(const NRVec< complex > &rhs)
{
#ifdef DEBUG
    if (mm != rhs.size()) laerror("incompatible matrix dimension in diagmultr");
#endif
    copyonwrite();
    for (int i=0; i

// [extraction damage] head and body truncated (Mat * SMat, double).
NRMat::operator*(const NRSMat &rhs) const
{
#ifdef DEBUG
    if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat");
#endif
    NRMat result(nn, rhs.ncols());
    for (int i=0; i

// [extraction damage] head and body truncated (Mat * SMat, complex).
> NRMat< complex >::operator*(const NRSMat< complex > &rhs) const
{
#ifdef DEBUG
    if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat");
#endif
    NRMat< complex > result(nn, rhs.ncols());
    for (int i=0; i

// [extraction damage] return type lost (Mat * Vec, double).
// Row-major matrix-vector product via cblas_dgemv; the result has nn elements.
NRMat::operator*(const NRVec &vec) const
{
#ifdef DEBUG
    if(mm != vec.size()) laerror("incompatible sizes in Mat*Vec");
#endif
    NRVec result(nn);
    cblas_dgemv(CblasRowMajor, CblasNoTrans, nn, mm, 1.0, (*this)[0], mm,
            &vec[0], 1, 0.0, &result[0], 1);
    return result;
}

// Complex variant via cblas_zgemv with alpha = CONE, beta = CZERO.
const NRVec< complex > NRMat< complex >::operator*(const NRVec< complex > &vec) const
{
#ifdef DEBUG
    if(mm != vec.size()) laerror("incompatible sizes in Mat*Vec");
#endif
    NRVec< complex > result(nn);
    cblas_zgemv(CblasRowMajor, CblasNoTrans, nn, mm, (void *)&CONE, (void *)(*this)[0],
            mm, (void *)&vec[0], 1, (void *)&CZERO, (void *)&result[0], 1);
    return result;
}

// sum of rows
// [extraction damage] the loop body, the end of this function and the head
// of csum() are missing.
const NRVec NRMat::rsum() const
{
    NRVec result(mm);
    for (int i=0; i

// [extraction damage] head and body truncated (csum, BLAS variant).
NRMat::csum() const
{
    NRVec result(nn);
    for (int i=0; i

// [extraction damage] return type lost.  This overload leaves the matrix
// unchanged and returns *this.
&NRMat::conjugateme() {return *this;}

// Complex variant: multiplies every second double, starting at offset 1,
// by -1 (the imaginary parts, presuming complex is laid out as re,im pairs).
NRMat< complex > & NRMat< complex >::conjugateme()
{
    copyonwrite();
    cblas_dscal(mm*nn, -1.0, (double *)((*this)[0])+1, 2);
    return *this;
}

// transpose and optionally conjugate
// [extraction damage] the loop body, the end of this function and the head
// of the complex overload are missing.
const NRMat NRMat::transpose(bool conj) const
{
    NRMat result(mm,nn);
    for(int i=0; i

// [extraction damage] head and body truncated (transpose, complex).
> NRMat< complex >::transpose(bool conj) const
{
    NRMat< complex > result(mm,nn);
    for (int i=0; i

// [extraction damage] return type and class qualifier lost.
// Forwards to cblas_dgemm: *this = alpha*op(a)*op(b) + beta*(*this), where
// op(x) is x when the matching flag is 'n' and the transpose for any other
// character.  Returns early without touching *this when alpha==0 and beta==1.
::gemm(const double &beta, const NRMat &a, const char transa, const NRMat &b,
        const char transb, const double &alpha)
{
    int l(transa=='n'?a.nn:a.mm);
    int k(transa=='n'?a.mm:a.nn);
    int kk(transb=='n'?b.nn:b.mm);
    int ll(transb=='n'?b.mm:b.nn);
#ifdef DEBUG
    if (l!=nn || ll!=mm || k!=kk) laerror("incompatible matrices in Mat:gemm()");
#endif
    if (alpha==0.0 && beta==1.0) return;
    copyonwrite();
    cblas_dgemm(CblasRowMajor, (transa=='n' ? CblasNoTrans : CblasTrans),
            (transb=='n' ? CblasNoTrans : CblasTrans), nn, mm, k, alpha, a,
            a.mm, b , b.mm, beta, *this , mm);
}

// Complex variant via cblas_zgemm.  Flags: 'n' = as is, 'c' = conjugate
// transpose, anything else = plain transpose.
// Fix: the operation applied to b is now chosen from transb; it used to test
// transa=='c', so b was (not) conjugated according to a's flag.
void NRMat< complex >::gemm(const complex & beta, const NRMat< complex > & a,
        const char transa, const NRMat< complex > & b, const char transb,
        const complex & alpha)
{
    int l(transa=='n'?a.nn:a.mm);
    int k(transa=='n'?a.mm:a.nn);
    int kk(transb=='n'?b.nn:b.mm);
    int ll(transb=='n'?b.mm:b.nn);
#ifdef DEBUG
    if (l!=nn || ll!=mm || k!=kk) laerror("incompatible matrices in Mat:gemm()");
#endif
    if (alpha==CZERO && beta==CONE) return;
    copyonwrite();
    cblas_zgemm(CblasRowMajor,
            (transa=='n' ? CblasNoTrans : (transa=='c'?CblasConjTrans:CblasTrans)),
            (transb=='n' ? CblasNoTrans : (transb=='c'?CblasConjTrans:CblasTrans)),
            nn, mm, k, &alpha, a , a.mm, b , b.mm, &beta, *this , mm);
}

// norm of Mat
// With scalar == 0 this is cblas_dnrm2 over all nn*mm elements.
// [extraction damage] the non-zero-scalar branch and the head of the complex
// overload are missing.
const double NRMat::norm(const double scalar) const
{
    if (!scalar) return cblas_dnrm2(nn*mm, (*this)[0], 1);
    double sum = 0;
    for (int i=0; i

// [extraction damage] head lost and the loop headers are truncated.  The
// surviving tail subtracts scalar from diagonal elements, accumulates
// re^2 + im^2 of each element and returns the square root of the total.
>::norm(const complex scalar) const
{
    if (scalar == CZERO) return cblas_dznrm2(nn*mm, (*this)[0], 1);
    double sum = 0;
    for (int i=0; i tmp;
#ifdef MATPTR
            tmp = v[i][j];
#else
            tmp = v[i*mm+j];
#endif
            if (i==j) tmp -= scalar;
            sum += tmp.real()*tmp.real()+tmp.imag()*tmp.imag();
        }
    return sqrt(sum);
}

// axpy: this += alpha * mat
void NRMat::axpy(const double alpha, const NRMat &mat)
{
#ifdef DEBUG
    if (nn!=mat.nn || mm!=mat.mm) laerror("daxpy of incompatible matrices");
#endif
    copyonwrite();
    cblas_daxpy(nn*mm, alpha, mat, 1, *this, 1);
}

// Complex variant via cblas_zaxpy.
void NRMat< complex >::axpy(const complex alpha, const NRMat< complex > & mat)
{
#ifdef DEBUG
    if (nn!=mat.nn || mm!=mat.mm) laerror("zaxpy of incompatible matrices");
#endif
    copyonwrite();
    cblas_zaxpy(nn*mm, (void *)&alpha, mat, 1, (void *)(*this)[0], 1);
}

// trace of Mat
// Returns the signed sum of the diagonal elements.
// Fix: this used cblas_dasum, which adds up absolute values and therefore
// returned the wrong result whenever a diagonal element was negative.
const double NRMat::trace() const
{
#ifdef DEBUG
    if (nn != mm) laerror("no-square matrix in Mat::trace()");
#endif
    const double *diag = (*this)[0];
    double sum = 0.0;
    // same stride (nn+1) between diagonal elements as the dasum call used
    for (int i = 0; i < nn; ++i) sum += diag[i*(nn+1)];
    return sum;
}

// Complex trace.
// [extraction damage] the summation loop, the end of this function and the
// first lines of the INSTANTIZE macro are missing.
// The obsolete "register" storage class was dropped (ill-formed since C++17).
const complex NRMat< complex >::trace() const
{
#ifdef DEBUG
    if (nn != mm) laerror("no-square matrix in Mat::trace()");
#endif
    complex sum = CZERO;
    for (int i=0; i &x); \
template istream & operator>>(istream &s, NRMat< T > &x); \

INSTANTIZE(double)
INSTANTIZE(complex)

// Stream output: writes "rows cols" on one line, then the elements.
// [extraction damage] the element loops and the head of operator>> are missing.
// NOTE(review): "export" templates were removed from the language in C++11;
// confirm the toolchain still accepts this keyword.
export template ostream& operator<<(ostream &s, const NRMat &x)
{
    int i,j,n,m;
    n=x.nrows();
    m=x.ncols();
    s << n << ' ' << m << '\n';
    for(i=0;i

// Stream input: reads rows and cols, resizes x, then reads the elements.
// [extraction damage] head lost and the loop headers are truncated.
istream& operator>>(istream &s, NRMat &x)
{
    int i,j,n,m;
    s >> n >> m;
    x.resize(n,m);
    for(i=0;i>x[i][j] ;
    return s;
}