#ifndef _LA_MAT_H_ #define _LA_MAT_H_ #include "vec.h" #include "smat.h" template class NRMat { protected: int nn; int mm; #ifdef MATPTR T **v; #else T *v; #endif int *count; public: friend class NRVec; friend class NRSMat; inline NRMat() : nn(0), mm(0), v(0), count(0) {}; inline NRMat(const int n, const int m); inline NRMat(const T &a, const int n, const int m); NRMat(const T *a, const int n, const int m); inline NRMat(const NRMat &rhs); explicit NRMat(const NRSMat &rhs); #ifndef MATPTR NRMat(const NRVec &rhs, const int n, const int m); #endif ~NRMat(); inline int getcount() const {return count?*count:0;} NRMat & operator=(const NRMat &rhs); //assignment NRMat & operator=(const T &a); //assign a to diagonal NRMat & operator|=(const NRMat &rhs); //assignment to a new copy NRMat & operator+=(const T &a); //add diagonal NRMat & operator-=(const T &a); //substract diagonal NRMat & operator*=(const T &a); //multiply by a scalar NRMat & operator+=(const NRMat &rhs); NRMat & operator-=(const NRMat &rhs); NRMat & operator+=(const NRSMat &rhs); NRMat & operator-=(const NRSMat &rhs); const NRMat operator-() const; //unary minus inline const NRMat operator+(const T &a) const; inline const NRMat operator-(const T &a) const; inline const NRMat operator*(const T &a) const; inline const NRMat operator+(const NRMat &rhs) const; inline const NRMat operator-(const NRMat &rhs) const; inline const NRMat operator+(const NRSMat &rhs) const; inline const NRMat operator-(const NRSMat &rhs) const; const T dot(const NRMat &rhs) const; // scalar product of Mat.Mat const NRMat operator*(const NRMat &rhs) const; // Mat * Mat void diagmultl(const NRVec &rhs); //multiply by a diagonal matrix from L void diagmultr(const NRVec &rhs); //multiply by a diagonal matrix from R const NRMat operator*(const NRSMat &rhs) const; // Mat * Smat const NRMat operator&(const NRMat &rhs) const; // direct sum const NRMat operator|(const NRMat &rhs) const; // direct product const NRVec operator*(const NRVec &rhs) const; // Mat * Vec const NRVec rsum() const; //sum of rows const NRVec csum() const; //sum of columns inline T* operator[](const int i); //subscripting: pointer to row i inline const T* operator[](const int i) const; inline T& operator()(const int i, const int j); // (i,j) subscripts inline const T& operator()(const int i, const int j) const; inline int nrows() const; inline int ncols() const; void copyonwrite(); void resize(const int n, const int m); inline operator T*(); //get a pointer to the data inline operator const T*() const; NRMat & transposeme(); // square matrices only NRMat & conjugateme(); // square matrices only const NRMat transpose(bool conj=false) const; const NRMat conjugate() const; void gemm(const T &beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const T &alpha);//this = alpha*op( A )*op( B ) + beta*this /* void strassen(const T beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const T alpha);//this := alpha*op( A )*op( B ) + beta*this void s_cutoff(const int,const int,const int,const int) const; */ void fprintf(FILE *f, const char *format, const int modulo) const; void fscanf(FILE *f, const char *format); const double norm(const T scalar=(T)0) const; void axpy(const T alpha, const NRMat &x); // this += a*x inline const T amax() const; const T trace() const; //members concerning sparse matrix explicit NRMat(const SparseMat &rhs); // dense from sparse NRMat & operator+=(const SparseMat &rhs); NRMat & operator-=(const SparseMat &rhs); inline void simplify() {}; //just for compatibility with sparse ones //Strassen's multiplication (better than n^3, analogous syntax to gemm) void strassen(const T beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const T alpha);//this := alpha*op( A )*op( B ) + beta*this void s_cutoff(const int,const int,const int,const int) const; }; // ctors template NRMat::NRMat(const int n, const int m) : nn(n), mm(m), count(new int) { *count = 1; #ifdef MATPTR v = new T*[n]; v[0] = new T[m*n]; for (int i=1; i NRMat::NRMat(const T &a, const int n, const int m) : nn(n), mm(m), count(new int) { int i; T *p; *count = 1; #ifdef MATPTR v = new T*[n]; p = v[0] = new T[m*n]; for (int i=1; i NRMat::NRMat(const T *a, const int n, const int m) : nn(n), mm(m), count(new int) { *count = 1; #ifdef MATPTR v = new T*[n]; v[0] = new T[m*n]; for (int i=1; i NRMat::NRMat(const NRMat &rhs) { nn = rhs.nn; mm = rhs.mm; count = rhs.count; v = rhs.v; if (count) ++(*count); } template NRMat::NRMat(const NRSMat &rhs) { int i; nn = mm = rhs.nrows(); count = new int; *count = 1; #ifdef MATPTR v = new T*[nn]; v[0] = new T[mm*nn]; for (int i=1; i NRMat::NRMat(const NRVec &rhs, const int n, const int m) { #ifdef DEBUG if (n*m != rhs.nn) laerror("matrix dimensions incompatible with vector length"); #endif nn = n; mm = m; count = rhs.count; v = rhs.v; (*count)++; } #endif // Mat + Smat template inline const NRMat NRMat::operator+(const NRSMat &rhs) const { return NRMat(*this) += rhs; } // Mat - Smat template inline const NRMat NRMat::operator-(const NRSMat &rhs) const { return NRMat(*this) -= rhs; } // Mat[i] : pointer to the first element of i-th row template inline T* NRMat::operator[](const int i) { #ifdef DEBUG if (*count != 1) laerror("Mat lval use of [] with count > 1"); if (i<0 || i>=nn) laerror("Mat [] out of range"); if (!v) laerror("[] for unallocated Mat"); #endif #ifdef MATPTR return v[i]; #else return v+i*mm; #endif } template inline const T* NRMat::operator[](const int i) const { #ifdef DEBUG if (i<0 || i>=nn) laerror("Mat [] out of range"); if (!v) laerror("[] for unallocated Mat"); #endif #ifdef MATPTR return v[i]; #else return v+i*mm; #endif } // Mat(i,j) reference to the matrix element M_{ij} template inline T & NRMat::operator()(const int i, const int j) { #ifdef DEBUG if (*count != 1) laerror("Mat lval use of (,) with count > 1"); if (i<0 || i>=nn || j<0 || j>mm) laerror("Mat (,) out of range"); if (!v) laerror("(,) for unallocated Mat"); #endif #ifdef MATPTR return v[i][j]; #else return v[i*mm+j]; #endif } template inline const T & NRMat::operator()(const int i, const int j) const { #ifdef DEBUG if (i<0 || i>=nn || j<0 || j>mm) laerror("Mat (,) out of range"); if (!v) laerror("(,) for unallocated Mat"); #endif #ifdef MATPTR return v[i][j]; #else return v[i*mm+j]; #endif } // number of rows template inline int NRMat::nrows() const { return nn; } // number of columns template inline int NRMat::ncols() const { return mm; } // reference pointer to Mat template inline NRMat::operator T* () { #ifdef DEBUG if (!v) laerror("unallocated Mat in operator T*"); #endif #ifdef MATPTR return v[0]; #else return v; #endif } template inline NRMat::operator const T* () const { #ifdef DEBUG if (!v) laerror("unallocated Mat in operator T*"); #endif #ifdef MATPTR return v[0]; #else return v; #endif } // max element of Mat inline const double NRMat::amax() const { #ifdef MATPTR return v[0][cblas_idamax(nn*mm, v[0], 1)]; #else return v[cblas_idamax(nn*mm, v, 1)]; #endif } inline const complex NRMat< complex >::amax() const { #ifdef MATPTR return v[0][cblas_izamax(nn*mm, (void *)v[0], 1)]; #else return v[cblas_izamax(nn*mm, (void *)v, 1)]; #endif } // I/O template extern ostream& operator<<(ostream &s, const NRMat &x); template extern istream& operator>>(istream &s, NRMat &x); // generate operators: Mat + a, a + Mat, Mat * a NRVECMAT_OPER(Mat,+) NRVECMAT_OPER(Mat,-) NRVECMAT_OPER(Mat,*) // generate Mat + Mat, Mat - Mat NRVECMAT_OPER2(Mat,+) NRVECMAT_OPER2(Mat,-) #endif /* _LA_MAT_H_ */