#ifndef _LA_VEC_H_ #define _LA_VEC_H_ #include "laerror.h" extern "C" { #include "cblas.h" } #include #include #include #include using namespace std; template class NRVec; template class NRSMat; template class NRMat; template class SparseMat; ////////////////////////////////////////////////////////////////////////////// // Forward declarations template void lawritemat(FILE *file,const T *a,int r,int c, const char *form0,int nodim,int modulo, int issym); // Memory allocated constants for cblas routines const static complex CONE = 1.0, CMONE = -1.0, CZERO = 0.0; // Macros to construct binary operators +,-,*, from +=, -=, *= // for 3 cases: X + a, a + X, X + Y #define NRVECMAT_OPER(E,X) \ template \ inline const NR##E NR##E::operator X(const T &a) const \ { return NR##E(*this) X##= a; } \ \ template \ inline const NR##E operator X(const T &a, const NR##E &rhs) \ { return NR##E(rhs) X##= a; } #define NRVECMAT_OPER2(E,X) \ template \ inline const NR##E NR##E::operator X(const NR##E &a) const \ { return NR##E(*this) X##= a; } #include "smat.h" #include "mat.h" // NRVec class template class NRVec { protected: int nn; T *v; int *count; public: friend class NRSMat; friend class NRMat; inline NRVec(): nn(0),v(0),count(0){}; explicit inline NRVec(const int n) : nn(n), v(new T[n]), count(new int(1)) {}; inline NRVec(const T &a, const int n); inline NRVec(const T *a, const int n); inline NRVec(const NRVec &rhs); inline explicit NRVec(const NRSMat & S); #ifndef MATPTR explicit NRVec(const NRMat &rhs); #endif NRVec & operator=(const NRVec &rhs); NRVec & operator=(const T &a); //assign a to every element NRVec & operator|=(const NRVec &rhs); const NRVec operator-() const; inline NRVec & operator+=(const NRVec &rhs); inline NRVec & operator-=(const NRVec &rhs); inline NRVec & operator+=(const T &a); inline NRVec & operator-=(const T &a); inline NRVec & operator*=(const T &a); inline int getcount() const {return count?*count:0;} inline const NRVec operator+(const NRVec &rhs) const; inline const NRVec operator-(const NRVec &rhs) const; inline const NRVec operator+(const T &a) const; inline const NRVec operator-(const T &a) const; inline const NRVec operator*(const T &a) const; inline const T operator*(const NRVec &rhs) const; //scalar product -> ddot inline const NRVec operator*(const NRSMat & S) const; const NRVec operator*(const NRMat &mat) const; const NRMat operator|(const NRVec &rhs) const; inline const T sum() const; //sum of its elements inline const T dot(const T *a, const int stride=1) const; // ddot with a stride-vector inline T & operator[](const int i); inline const T & operator[](const int i) const; inline int size() const; inline operator T*(); //get a pointer to the data inline operator const T*() const; //get a pointer to the data ~NRVec(); void axpy(const T alpha, const NRVec &x); // this+= a*x void axpy(const T alpha, const T *x, const int stride=1); // this+= a*x void gemv(const T beta, const NRMat &a, const char trans, const T alpha, const NRVec &x); void copyonwrite(); void resize(const int n); NRVec & normalize(); inline const double norm() const; inline const T amax() const; inline const NRVec unitvector() const; void fprintf(FILE *f, const char *format, const int modulo) const; void fscanf(FILE *f, const char *format); //sparse matrix concerning members explicit NRVec(const SparseMat &rhs); // dense from sparse matrix with one of dimensions =1 const NRVec operator*(const SparseMat &mat) const; //vector*matrix inline void simplify() {}; //just for compatibility with sparse ones void gemv(const T beta, const SparseMat &a, const char trans, const T alpha, const NRVec &x); }; template ostream & operator<<(ostream &s, const NRVec &x); template istream & operator>>(istream &s, NRVec &x); // INLINES // ctors template inline NRVec::NRVec(const T& a, const int n) : nn(n), v(new T[n]), count(new int) { *count = 1; if(a != (T)0) for(int i=0; i inline NRVec::NRVec(const T *a, const int n) : nn(n), v(new T[n]), count(new int) { *count = 1; memcpy(v, a, n*sizeof(T)); } template inline NRVec::NRVec(const NRVec &rhs) { v = rhs.v; nn = rhs.nn; count = rhs.count; if(count) (*count)++; } template inline NRVec::NRVec(const NRSMat &rhs) { nn = rhs.nn; nn = NN2; v = rhs.v; count = rhs.count; (*count)++; } // x += a inline NRVec & NRVec::operator+=(const double &a) { copyonwrite(); cblas_daxpy(nn, 1.0, &a, 0, v, 1); return *this; } inline NRVec< complex > & NRVec< complex >::operator+=(const complex &a) { copyonwrite(); cblas_zaxpy(nn, (void *)(&CONE), (void *)(&a), 0, (void *)v, 1); return *this; } //and for general type template inline NRVec & NRVec::operator+=(const T &a) { copyonwrite(); int i; for(i=0; i & NRVec::operator-=(const double &a) { copyonwrite(); cblas_daxpy(nn, 1.0, &a, 0, v, 1); return *this; } inline NRVec< complex > & NRVec< complex >::operator-=(const complex &a) { copyonwrite(); cblas_zaxpy(nn, (void *)(&CMONE), (void *)(&a), 0, (void *)v, 1); return *this; } //and for general type template inline NRVec & NRVec::operator-=(const T &a) { copyonwrite(); int i; for(i=0; i & NRVec::operator+=(const NRVec &rhs) { #ifdef DEBUG if (nn != rhs.nn) laerror("daxpy of incompatible vectors"); #endif copyonwrite(); cblas_daxpy(nn, 1.0, rhs.v, 1, v, 1); return *this; } inline NRVec< complex > & NRVec< complex >::operator+=(const NRVec< complex > &rhs) { #ifdef DEBUG if (nn != rhs.nn) laerror("daxpy of incompatible vectors"); #endif copyonwrite(); cblas_zaxpy(nn, (void *)(&CONE), rhs.v, 1, v, 1); return *this; } //and for general type template inline NRVec & NRVec::operator+=(const NRVec &rhs) { #ifdef DEBUG if (nn != rhs.nn) laerror("daxpy of incompatible vectors"); #endif copyonwrite(); int i; for(i=0; i & NRVec::operator-=(const NRVec &rhs) { #ifdef DEBUG if (nn != rhs.nn) laerror("daxpy of incompatible vectors"); #endif copyonwrite(); cblas_daxpy(nn, -1.0, rhs.v, 1, v, 1); return *this; } inline NRVec< complex > & NRVec< complex >::operator-=(const NRVec< complex > &rhs) { #ifdef DEBUG if (nn != rhs.nn) laerror("daxpy of incompatible vectors"); #endif copyonwrite(); cblas_zaxpy(nn, (void *)(&CMONE), (void *)rhs.v, 1, (void *)v, 1); return *this; } //and for general type template inline NRVec & NRVec::operator-=(const NRVec &rhs) { #ifdef DEBUG if (nn != rhs.nn) laerror("daxpy of incompatible vectors"); #endif copyonwrite(); int i; for(i=0; i & NRVec::operator*=(const double &a) { copyonwrite(); cblas_dscal(nn, a, v, 1); return *this; } inline NRVec< complex > & NRVec< complex >::operator*=(const complex &a) { copyonwrite(); cblas_zscal(nn, (void *)(&a), (void *)v, 1); return *this; } //and for general type template inline NRVec & NRVec::operator*=(const T &a) { copyonwrite(); int i; for(i=0; i::operator*(const NRVec &rhs) const { #ifdef DEBUG if (nn != rhs.nn) laerror("ddot of incompatible vectors"); #endif return cblas_ddot(nn, v, 1, rhs.v, 1); } inline const complex NRVec< complex >::operator*(const NRVec< complex > &rhs) const { #ifdef DEBUG if (nn != rhs.nn) laerror("ddot of incompatible vectors"); #endif complex dot; cblas_zdotc_sub(nn, (void *)v, 1, (void *)rhs.v, 1, (void *)(&dot)); return dot; } // Vec * SMat = SMat * Vec template inline const NRVec NRVec::operator*(const NRSMat & S) const { return S * (*this); } // Sum of elements inline const double NRVec::sum() const { return cblas_dasum(nn, v, 1); } inline const complex NRVec< complex >::sum() const { complex sum = CZERO; for (int i=0; i::dot(const double *y, const int stride) const { return cblas_ddot(nn, y, stride, v, 1); } inline const complex NRVec< complex >::dot(const complex *y, const int stride) const { complex dot; cblas_zdotc_sub(nn, y, stride, v, 1, (void *)(&dot)); return dot; } // x[i] returns i-th element template inline T & NRVec::operator[](const int i) { #ifdef DEBUG if(*count != 1) laerror("possible lval [] with count > 1"); if(i < 0 || i >= nn) laerror("NRVec out of range"); if(!v) laerror("[] on unallocated NRVec"); #endif return v[i]; } template inline const T & NRVec::operator[](const int i) const { #ifdef DEBUG if(i < 0 || i >= nn) laerror("NRVec out of range"); if(!v) laerror("[] on unallocated NRVec"); #endif return v[i]; } // length of the vector template inline int NRVec::size() const { return nn; } // reference Vec to the first element template inline NRVec::operator T*() { #ifdef DEBUG if(!v) laerror("unallocated NRVec in operator T*"); #endif return v; } template inline NRVec::operator const T*() const { #ifdef DEBUG if(!v) laerror("unallocated NRVec in operator T*"); #endif return v; } // return norm of the Vec inline const double NRVec::norm() const { return cblas_dnrm2(nn, v, 1); } inline const double NRVec< complex >::norm() const { return cblas_dznrm2(nn, (void *)v, 1); } // Max element of the array inline const double NRVec::amax() const { return v[cblas_idamax(nn, v, 1)]; } inline const complex NRVec< complex >::amax() const { return v[cblas_izamax(nn, (void *)v, 1)]; } // Make Vec unitvector template inline const NRVec NRVec::unitvector() const { return NRVec(*this).normalize(); } // generate operators: Vec + a, a + Vec, Vec * a NRVECMAT_OPER(Vec,+) NRVECMAT_OPER(Vec,-) NRVECMAT_OPER(Vec,*) // generate operators: Vec + Vec, Vec - Vec NRVECMAT_OPER2(Vec,+) NRVECMAT_OPER2(Vec,-) // Few forward declarations //basic stuff which has to be in .h // dtor template NRVec::~NRVec() { if(!count) return; if(--(*count) <= 0) { if(v) delete[] (v); delete count; } } // detach from a physical vector and make own copy template void NRVec::copyonwrite() { #ifdef DEBUG if(!count) laerror("probably an assignment to undefined vector"); #endif if(*count > 1) { (*count)--; count = new int; *count = 1; T *newv = new T[nn]; memcpy(newv, v, nn*sizeof(T)); v = newv; } } // Asignment template NRVec & NRVec::operator=(const NRVec &rhs) { if (this != &rhs) { if(count) if(--(*count) == 0) { delete[] v; delete count; } v = rhs.v; nn = rhs.nn; count = rhs.count; if(count) (*count)++; } return *this; } // Resize template void NRVec::resize(const int n) { #ifdef DEBUG if(n<=0) laerror("illegal vector dimension"); #endif if(count) if(*count > 1) { (*count)--; count = 0; v = 0; nn = 0; } if(!count) { count = new int; *count = 1; nn = n; v = new T[nn]; return; } // *count = 1 in this branch if (n != nn) { nn = n; delete[] v; v = new T[nn]; } } // assignmet with a physical (deep) copy template NRVec & NRVec::operator|=(const NRVec &rhs) { if (this != &rhs) { #ifdef DEBUG if (!rhs.v) laerror("unallocated rhs in NRVec operator |="); #endif if (count) if (*count > 1) { --(*count); nn = 0; count = 0; v = 0; } if (nn != rhs.nn) { if (v) delete[] (v); nn = rhs.nn; } if(!v) v = new T[nn]; if(!count) count = new int; *count = 1; memcpy(v, rhs.v, nn*sizeof(T)); } return *this; } #endif /* _LA_VEC_H_ */