diff --git a/auxstorage.h b/auxstorage.h index b12eae6..80a642c 100644 --- a/auxstorage.h +++ b/auxstorage.h @@ -1,6 +1,5 @@ #ifndef _AUXSTORAGE_H_ #define _AUXSTORAGE_H_ -#include "laerror.h" #include "vec.h" #include "mat.h" #include "smat.h" @@ -68,6 +67,7 @@ void AuxStorage::get(NRVec &x, const int pos) const { if(recl==0) laerror("get from an empty file in AuxStorage"); if((off64_t)-1 == lseek64(fd,pos*((off64_t)recl),SEEK_SET)) {perror(""); laerror("seek failed in AuxStorage");} +x.copyonwrite(); if((ssize_t)recl!=read(fd,&x[0],recl)) {perror(""); laerror("read failed in AuxStorage");} } @@ -86,6 +86,7 @@ void AuxStorage::get(NRMat &x, const int pos) const { if(recl==0) laerror("get from an empty file in AuxStorage"); if((off64_t)-1 == lseek64(fd,pos*((off64_t)recl),SEEK_SET)) {perror(""); laerror("seek failed in AuxStorage");} +x.copyonwrite(); if((ssize_t)recl!=read(fd,&x(0,0),recl)) {perror(""); laerror("read failed in AuxStorage");} } @@ -104,6 +105,7 @@ void AuxStorage::get(NRSMat &x, const int pos) const { if(recl==0) laerror("get from an empty file in AuxStorage"); if((off64_t)-1 == lseek64(fd,pos*((off64_t)recl),SEEK_SET)) {perror(""); laerror("seek failed in AuxStorage");} +x.copyonwrite(); if((ssize_t)recl!=read(fd,&x(0,0),recl)) {perror(""); laerror("read failed in AuxStorage");} } diff --git a/diis.h b/diis.h index 6feb1bf..d6f09b8 100644 --- a/diis.h +++ b/diis.h @@ -1,4 +1,4 @@ -//DIIS convergence acceleration +//DIIS convergence acceleration according to Pulay: Chem. Phys. Lett. 73, 393 (1980); J. Comp. Chem. 3,556 (1982) #ifndef _DIIS_H_ #define _DIIS_H_ #include "vec.h" diff --git a/la.h b/la.h index d787b11..e0aab9a 100644 --- a/la.h +++ b/la.h @@ -1,9 +1,7 @@ #ifndef _LA_H_ #define _LA_H_ -#ifdef _GLIBCPP_NO_TEMPLATE_EXPORT -# define export -#endif +//this should be the single include file for the end user #include "vec.h" #include "smat.h" diff --git a/la_traits.h b/la_traits.h index ef5c710..7290515 100644 --- a/la_traits.h +++ b/la_traits.h @@ -1,9 +1,26 @@ //////////////////////////////////////////////////////////////////////////// -//LA traits classes +//LA traits classes and generally needed includes #ifndef _LA_TRAITS_INCL #define _LA_TRAITS_INCL + + +using namespace std; +#include +#include +#include +#include +#include "laerror.h" + +extern "C" { +#include "cblas.h" +} + +#ifdef _GLIBCPP_NO_TEMPLATE_EXPORT +# define export +#endif + //forward declarations template class NRVec; template class NRMat; diff --git a/mat.cc b/mat.cc index 2269d06..a18f09c 100644 --- a/mat.cc +++ b/mat.cc @@ -187,13 +187,14 @@ const NRVec NRMat::rsum() const // transpose Mat template -NRMat & NRMat::transposeme() +NRMat & NRMat::transposeme(int n) { +if(n==0) n=nn; #ifdef DEBUG - if (nn != mm) laerror("transpose of non-square Mat"); + if (n==nn && nn != mm || n>mm || n>nn) laerror("transpose of non-square Mat"); #endif copyonwrite(); - for(int i=1; i::fscanf(FILE *f, const char *format) - - - - - - - - - - - - - - - - - - - - /* * BLAS specializations for double and complex @@ -602,29 +583,6 @@ NRMat< complex >::operator*(const NRSMat< complex > &rhs) const return result; } -// Mat * Vec -const NRVec -NRMat::operator*(const NRVec &vec) const -{ -#ifdef DEBUG - if(mm != vec.size()) laerror("incompatible sizes in Mat*Vec"); -#endif - NRVec result(nn); - cblas_dgemv(CblasRowMajor, CblasNoTrans, nn, mm, 1.0, (*this)[0], - mm, &vec[0], 1, 0.0, &result[0], 1); - return result; -} -const NRVec< complex > -NRMat< complex >::operator*(const NRVec< complex > &vec) const -{ -#ifdef DEBUG - if(mm != vec.size()) laerror("incompatible sizes in Mat*Vec"); -#endif - NRVec< complex > result(nn); - cblas_zgemv(CblasRowMajor, CblasNoTrans, nn, mm, (void *)&CONE, (void *)(*this)[0], - mm, (void *)&vec[0], 1, (void *)&CZERO, (void *)&result[0], 1); - return result; -} // sum of rows const NRVec NRMat::rsum() const diff --git a/mat.h b/mat.h index c2e2275..9e587ea 100644 --- a/mat.h +++ b/mat.h @@ -1,8 +1,6 @@ #ifndef _LA_MAT_H_ #define _LA_MAT_H_ -#include "vec.h" -#include "smat.h" #include "la_traits.h" template @@ -55,7 +53,7 @@ public: inline const NRMat operator-(const NRMat &rhs) const; inline const NRMat operator+(const NRSMat &rhs) const; inline const NRMat operator-(const NRSMat &rhs) const; - const T dot(const NRMat &rhs) const; // scalar product of Mat.Mat + const T dot(const NRMat &rhs) const; // scalar product of Mat.Mat//@@@for complex do conjugate const NRMat operator*(const NRMat &rhs) const; // Mat * Mat const NRMat oplus(const NRMat &rhs) const; //direct sum const NRMat otimes(const NRMat &rhs) const; //direct product @@ -64,7 +62,7 @@ public: const NRMat operator*(const NRSMat &rhs) const; // Mat * Smat const NRMat operator&(const NRMat &rhs) const; // direct sum const NRMat operator|(const NRMat &rhs) const; // direct product - const NRVec operator*(const NRVec &rhs) const; // Mat * Vec + const NRVec operator*(const NRVec &rhs) const {NRVec result(nn); result.gemv((T)0,*this,'n',(T)1,rhs); return result;}; // Mat * Vec const NRVec rsum() const; //sum of rows const NRVec csum() const; //sum of columns void diagonalof(NRVec &, const bool divide=0) const; //get diagonal @@ -74,13 +72,14 @@ public: inline const T& operator()(const int i, const int j) const; inline int nrows() const; inline int ncols() const; + inline int size() const; void get(int fd, bool dimensions=1); void put(int fd, bool dimensions=1) const; void copyonwrite(); void resize(const int n, const int m); inline operator T*(); //get a pointer to the data inline operator const T*() const; - NRMat & transposeme(); // square matrices only + NRMat & transposeme(int n=0); // square matrices only NRMat & conjugateme(); // square matrices only const NRMat transpose(bool conj=false) const; const NRMat conjugate() const; @@ -103,6 +102,7 @@ public: NRMat & operator+=(const SparseMat &rhs); NRMat & operator-=(const SparseMat &rhs); inline void simplify() {}; //just for compatibility with sparse ones + bool issymmetric() const {return 0;}; //Strassen's multiplication (better than n^3, analogous syntax to gemm) void strassen(const T beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const T alpha);//this := alpha*op( A )*op( B ) + beta*this @@ -110,6 +110,11 @@ public: }; +//due to mutual includes this has to be after full class declaration +#include "vec.h" +#include "smat.h" +#include "sparsemat.h" + // ctors template NRMat::NRMat(const int n, const int m) : nn(n), mm(m), count(new int) @@ -294,6 +299,12 @@ inline int NRMat::ncols() const return mm; } +template +inline int NRMat::size() const +{ + return nn*mm; +} + // reference pointer to Mat template inline NRMat::operator T* () diff --git a/nonclass.cc b/nonclass.cc index 41011d4..7829e14 100644 --- a/nonclass.cc +++ b/nonclass.cc @@ -2,7 +2,11 @@ extern "C" { #include "atlas_enum.h" #include "clapack.h" } -#include "la.h" +#include "vec.h" +#include "smat.h" +#include "mat.h" +#include "nonclass.h" + #ifdef FORTRAN_ #define FORNAME(x) x##_ diff --git a/nonclass.h b/nonclass.h index c2b69a0..f1efb40 100644 --- a/nonclass.h +++ b/nonclass.h @@ -74,11 +74,6 @@ extern void gdiagonalize(NRMat &a, NRVec< complex > &w, extern NRMat matrixfunction(NRSMat a, double (*f) (double)); extern NRMat matrixfunction(NRMat a, complex (*f)(const complex &),const bool adjust=0); - -////////////////////////////// -//other than lapack functions/ -////////////////////////////// - //functions on matrices inline NRMat sqrt(const NRSMat &a) { return matrixfunction(a,&sqrt); } inline NRMat log(const NRSMat &a) { return matrixfunction(a,&log); } diff --git a/smat.cc b/smat.cc index f165426..dc02d40 100644 --- a/smat.cc +++ b/smat.cc @@ -248,27 +248,6 @@ NRSMat< complex >::dot(const NRSMat< complex > &rhs) const return dot; } -// x = S * x -const NRVec NRSMat::operator*(const NRVec &rhs) const -{ -#ifdef DEBUG - if (nn!=rhs.size()) laerror("incompatible dimension in Smat*Vec"); -#endif - NRVec result(nn); - cblas_dspmv(CblasRowMajor, CblasLower, nn, 1.0, v, rhs, 1, 0.0, result, 1); - return result; -} -const NRVec< complex > -NRSMat< complex >::operator*(const NRVec< complex > &rhs) const -{ -#ifdef DEBUG - if (nn!=rhs.size()) laerror("incompatible dimension in Smat*Vec"); -#endif - NRVec< complex > result(nn); - cblas_zhpmv(CblasRowMajor, CblasLower, nn, (void *)(&CONE), (void *)v, - (const void *)rhs, 1, (void *)(&CZERO), (void *)result, 1); - return result; -} // norm of the matrix const double NRSMat::norm(const double scalar) const @@ -347,18 +326,6 @@ istream& operator>>(istream &s, NRSMat &x) return s; } -//not implemented yet -const NRVec NRSMat::operator*(NRVec const&rhs) const -{ -laerror("NRSMat::operator*(NRVec const&) not implemented yet"); -return rhs; -} - -const NRVec NRSMat::operator*(NRVec const&rhs) const -{ -laerror("NRSMat::operator*(NRVec const&) not implemented yet"); -return rhs; -} diff --git a/smat.h b/smat.h index 1f38d91..913959d 100644 --- a/smat.h +++ b/smat.h @@ -1,8 +1,6 @@ #ifndef _LA_SMAT_H_ #define _LA_SMAT_H_ -#include "vec.h" -#include "mat.h" #include "la_traits.h" #define NN2 (nn*(nn+1)/2) @@ -44,8 +42,8 @@ public: inline const NRMat operator-(const NRMat &rhs) const; const NRMat operator*(const NRSMat &rhs) const; // SMat*SMat const NRMat operator*(const NRMat &rhs) const; // SMat*Mat - const T dot(const NRSMat &rhs) const; // Smat.Smat - const NRVec operator*(const NRVec &rhs) const; + const T dot(const NRSMat &rhs) const; // Smat.Smat//@@@for complex do conjugate + const NRVec operator*(const NRVec &rhs) const {NRVec result(nn); result.gemv((T)0,*this,'n',(T)1,rhs); return result;}; // Mat * Vec void diagonalof(NRVec &, const bool divide=0) const; //get diagonal inline const T& operator[](const int ij) const; inline T& operator[](const int ij); @@ -53,6 +51,7 @@ public: inline T& operator()(const int i, const int j); inline int nrows() const; inline int ncols() const; + inline int size() const; const double norm(const T scalar=(T)0) const; void axpy(const T alpha, const NRSMat &x); // this+= a*x inline const T amax() const; @@ -69,9 +68,15 @@ public: //members concerning sparse matrix explicit NRSMat(const SparseMat &rhs); // dense from sparse inline void simplify() {}; //just for compatibility with sparse ones + bool issymmetric() const {return 1;} }; -// INLINES +//due to mutual includes this has to be after full class declaration +#include "vec.h" +#include "mat.h" +#include "sparsemat.h" + + // ctors template inline NRSMat::NRSMat(const int n) : nn(n), v(new T[NN2]), @@ -293,6 +298,13 @@ inline int NRSMat::ncols() const return nn; } +template +inline int NRSMat::size() const +{ + return NN2; +} + + // max value inline const double NRSMat::amax() const { diff --git a/sparsemat.cc b/sparsemat.cc index f4487cf..787e22f 100644 --- a/sparsemat.cc +++ b/sparsemat.cc @@ -1,7 +1,5 @@ #include #include -#include -#include #include #include #include @@ -14,10 +12,6 @@ template SparseMat; template SparseMat >; -#ifdef _GLIBCPP_NO_TEMPLATE_EXPORT -# define export -#endif - export template ostream& operator<<(ostream &s, const SparseMat &x) @@ -403,7 +397,7 @@ export template void SparseMat::incsize(const SPMatindex n, const SPMatindex m) { if(symmetric && n!=m) laerror("unsymmetric size increment of a symmetric sparsemat"); - if(!count && nn==0 && mm==0) count=new int(1); + if(!count ) count=new int(1); copyonwrite();//this errors if !count unsort(); nn+=n; @@ -883,34 +877,8 @@ else } -//multiplication with dense vector from both sides -template -const NRVec SparseMat::multiplyvector(const NRVec &vec, const bool transp) const -{ -if(transp && nn!=(SPMatindex)vec.size() || !transp && mm!=(SPMatindex)vec.size()) laerror("incompatible sizes in sparsemat*vector"); -NRVec result(transp?mm:nn); -result.gemv((T)0, *this, transp?'t':'n', (T)1., vec); -return result; -} -template -const NRVec NRVec::operator*(const SparseMat &mat) const -{ -if(mat.nrows()!= (SPMatindex)size()) laerror("incompatible sizes in vector*sparsemat"); -NRVec result((T)0,mat.ncols()); -matel *l=mat.getlist(); -bool symmetric=mat.issymmetric(); -while(l) - { - result.v[l->col]+= l->elem*v[l->row]; - if(symmetric&&l->row!=l->col) result.v[l->row]+= l->elem*v[l->col]; - l=l->next; - } -return result; - -} - template const T SparseMat::trace() const { @@ -1249,7 +1217,6 @@ template SparseMat & SparseMat::operator-=(const T a); \ template NRMat::NRMat(const SparseMat &rhs); \ template NRSMat::NRSMat(const SparseMat &rhs); \ template NRVec::NRVec(const SparseMat &rhs); \ -template const NRVec SparseMat::operator*(const NRVec &vec) const; \ template const NRVec NRVec::operator*(const SparseMat &mat) const; \ template SparseMat & SparseMat::join(SparseMat &rhs); \ template const T SparseMat::trace() const; \ @@ -1263,7 +1230,7 @@ template void NRVec::gemv(const T beta, const SparseMat &a, const char tra INSTANTIZE(double) +INSTANTIZE(complex) //some functions are not OK for hermitean matrices, needs a revision!!! -// some functions are not OK for hermitean! INSTANTIZE(complex) #endif diff --git a/sparsemat.h b/sparsemat.h index 8d3a10f..7e033d8 100644 --- a/sparsemat.h +++ b/sparsemat.h @@ -1,14 +1,13 @@ #ifndef _SPARSEMAT_H_ #define _SPARSEMAT_H_ -//for vectors and dense matrices we shall need -#include "la.h" +#include "la_traits.h" -template +template inline const T MAX(const T &a, const T &b) {return b > a ? (b) : (a);} -template +template inline void SWAP(T &a, T &b) {T dum=a; a=b; b=dum;} @@ -21,7 +20,7 @@ typedef unsigned int SPMatindex; typedef int SPMatindexdiff; //more clear would be to use traits //element of a linked list -template +template struct matel { T elem; @@ -31,7 +30,7 @@ struct matel }; -template +template class SparseMat { protected: SPMatindex nn; @@ -86,8 +85,7 @@ public: inline const SparseMat operator*(const T &rhs) const {return SparseMat(*this) *= rhs;} inline const SparseMat operator+(const SparseMat &rhs) const {return SparseMat(*this) += rhs;} //must not be symmetric+general inline const SparseMat operator-(const SparseMat &rhs) const {return SparseMat(*this) -= rhs;} //must not be symmetric+general - const NRVec multiplyvector(const NRVec &rhs, const bool transp=0) const; //sparse matrix * dense vector optionally transposed - inline const NRVec operator*(const NRVec &rhs) const {return multiplyvector(rhs);} //sparse matrix * dense vector + inline const NRVec operator*(const NRVec &rhs) const; // Mat * Vec void diagonalof(NRVec &, const bool divide=0) const; //get diagonal const SparseMat operator*(const SparseMat &rhs) const; SparseMat & oplusequal(const SparseMat &rhs); //direct sum @@ -128,14 +126,23 @@ public: void addsafe(const SPMatindex n, const SPMatindex m, const T elem); }; -template +//due to mutual includes this has to be after full class declaration +#include "vec.h" +#include "smat.h" +#include "mat.h" + +template +inline const NRVec SparseMat::operator*(const NRVec &rhs) const +{NRVec result(nn); result.gemv((T)0,*this,'n',(T)1,rhs); return result;}; + +template extern istream& operator>>(istream &s, SparseMat &x); -template +template extern ostream& operator<<(ostream &s, const SparseMat &x); //destructor -template +template SparseMat::~SparseMat() { unsort(); @@ -148,7 +155,7 @@ SparseMat::~SparseMat() } //copy constructor (sort arrays are not going to be copied) -template +template SparseMat::SparseMat(const SparseMat &rhs) { #ifdef debug @@ -164,7 +171,7 @@ if(! &rhs) laerror("SparseMat copy constructor with NULL argument"); nonzero=0; } -template +template const SparseMat SparseMat::transpose() const { if(list&&!count) laerror("some inconsistency in SparseMat transpose"); @@ -195,7 +202,7 @@ return result; -template +template inline const SparseMat commutator ( const SparseMat &x, const SparseMat &y, const bool trx=0, const bool tryy=0) { SparseMat r; @@ -204,7 +211,7 @@ r.gemm((T)1,y,tryy?'t':'n',x,trx?'t':'n',(T)-1); //saves a temporary and simplif return r; } -template +template inline const SparseMat anticommutator ( const SparseMat &x, const SparseMat &y, const bool trx=0, const bool tryy=0) { SparseMat r; @@ -215,7 +222,7 @@ return r; //add sparse to dense -template +template NRMat & NRMat::operator+=(const SparseMat &rhs) { if((unsigned int)nn!=rhs.nrows()||(unsigned int)mm!=rhs.ncols()) laerror("incompatible matrices in +="); diff --git a/strassen.cc b/strassen.cc index 5179206..caa6044 100644 --- a/strassen.cc +++ b/strassen.cc @@ -1,4 +1,6 @@ -#include "la.h" + +#include "mat.h" + /*Strassen algorithm*/ // called routine is fortran-compatible extern "C" void fmm(const char c_transa,const char c_transb,const int m,const int n,const int k,const double alpha, diff --git a/vec.cc b/vec.cc index e5d187b..5fb81b1 100644 --- a/vec.cc +++ b/vec.cc @@ -15,6 +15,10 @@ extern ssize_t write(int, const void *, size_t); #define INSTANTIZE(T) \ template ostream & operator<<(ostream &s, const NRVec< T > &x); \ template istream & operator>>(istream &s, NRVec< T > &x); \ +template void NRVec::put(int fd, bool dim) const; \ +template void NRVec::get(int fd, bool dim); \ + + INSTANTIZE(double) INSTANTIZE(complex) @@ -26,8 +30,10 @@ INSTANTIZE(char) INSTANTIZE(unsigned char) template NRVec; template NRVec >; -template NRVec; template NRVec; +template NRVec; + + /* @@ -228,13 +234,55 @@ NRVec< complex > & NRVec< complex >::normalize() return *this; } -//and for these types it does not make sense to normalize but we have them for linkage +//stubs for linkage NRVec & NRVec::normalize() {laerror("normalize() impossible for integer types"); return *this;} NRVec & NRVec::normalize() {laerror("normalize() impossible for integer types"); return *this;} +void NRVec::gemv(const int beta, + const NRSMat &A, const char trans, + const int alpha, const NRVec &x) +{ +laerror("not yet implemented"); +} + +void NRVec::gemv(const char beta, + const NRSMat &A, const char trans, + const char alpha, const NRVec &x) +{ +laerror("not yet implemented"); +} + +void NRVec::gemv(const int beta, + const NRMat &A, const char trans, + const int alpha, const NRVec &x) +{ +laerror("not yet implemented"); +} + +void NRVec::gemv(const char beta, + const NRMat &A, const char trans, + const char alpha, const NRVec &x) +{ +laerror("not yet implemented"); +} + +void NRVec::gemv(const int beta, + const SparseMat &A, const char trans, + const int alpha, const NRVec &x) +{ +laerror("not yet implemented"); +} + +void NRVec::gemv(const char beta, + const SparseMat &A, const char trans, + const char alpha, const NRVec &x) +{ +laerror("not yet implemented"); +} -// gemv call + +// gemv calls void NRVec::gemv(const double beta, const NRMat &A, const char trans, const double alpha, const NRVec &x) { @@ -243,8 +291,9 @@ void NRVec::gemv(const double beta, const NRMat &A, laerror("incompatible sizes in gemv A*x"); #endif cblas_dgemv(CblasRowMajor, (trans=='n' ? CblasNoTrans:CblasTrans), - A.nrows(), A.ncols(), alpha, A[0], A.ncols(), x.v, 1, beta, v, 1); + A.nrows(), A.ncols(), alpha, A, A.ncols(), x.v, 1, beta, v, 1); } + void NRVec< complex >::gemv(const complex beta, const NRMat< complex > &A, const char trans, const complex alpha, const NRVec &x) @@ -254,35 +303,38 @@ void NRVec< complex >::gemv(const complex beta, laerror("incompatible sizes in gemv A*x"); #endif cblas_zgemv(CblasRowMajor, (trans=='n' ? CblasNoTrans:CblasTrans), - A.nrows(), A.ncols(), (void *)(&alpha), (void *)A[0], A.ncols(), - (void *)x.v, 1, (void *)(&beta), (void *)v, 1); + A.nrows(), A.ncols(), &alpha, A, A.ncols(), + x.v, 1, &beta, v, 1); } -// Vec * Mat -const NRVec NRVec::operator*(const NRMat &mat) const + +void NRVec::gemv(const double beta, const NRSMat &A, + const char trans, const double alpha, const NRVec &x) { #ifdef DEBUG - if(mat.nrows() != nn) laerror("incompatible sizes in Vec*Mat"); + if (A.ncols()!=x.size()) laerror("incompatible dimension in gemv A*x"); #endif - int n = mat.ncols(); - NRVec result(n); - cblas_dgemv(CblasRowMajor, CblasTrans, nn, n, 1.0, mat[0], n, v, 1, - 0.0, result.v, 1); - return result; + NRVec result(nn); + cblas_dspmv(CblasRowMajor, CblasLower, A.ncols(), alpha, A, x.v, 1, beta, v, 1); } -const NRVec< complex > -NRVec< complex >::operator*(const NRMat< complex > &mat) const + + +void NRVec< complex >::gemv(const complex beta, + const NRSMat< complex > &A, const char trans, + const complex alpha, const NRVec &x) { #ifdef DEBUG - if(mat.nrows() != nn) laerror("incompatible sizes in Vec*Mat"); + if (A.ncols()!=x.size()) laerror("incompatible dimension in gemv"); #endif - int n = mat.ncols(); - NRVec< complex > result(n); - cblas_zgemv(CblasRowMajor, CblasTrans, nn, n, &CONE, mat[0], n, v, 1, - &CZERO, result.v, 1); - return result; + NRVec< complex > result(nn); + cblas_zhpmv(CblasRowMajor, CblasLower, A.ncols(), &alpha, A, + x.v, 1, &beta, v, 1); } + + + + // Direc product Mat = Vec | Vec const NRMat NRVec::operator|(const NRVec &b) const { diff --git a/vec.h b/vec.h index be353c4..9d980bc 100644 --- a/vec.h +++ b/vec.h @@ -1,24 +1,8 @@ #ifndef _LA_VEC_H_ #define _LA_VEC_H_ -#include "laerror.h" -extern "C" { -#include "cblas.h" -} -#include -#include -#include -#include - -using namespace std; - #include "la_traits.h" -template class NRVec; -template class NRSMat; -template class NRMat; -template class SparseMat; - ////////////////////////////////////////////////////////////////////////////// // Forward declarations template void lawritemat(FILE *file,const T *a,int r,int c, @@ -43,9 +27,6 @@ template \ inline const NR##E NR##E::operator X(const NR##E &a) const \ { return NR##E(*this) X##= a; } -#include "smat.h" -#include "mat.h" - // NRVec class template @@ -84,9 +65,14 @@ public: inline const NRVec operator+(const T &a) const; inline const NRVec operator-(const T &a) const; inline const NRVec operator*(const T &a) const; - inline const T operator*(const NRVec &rhs) const; //scalar product -> ddot - inline const NRVec operator*(const NRSMat & S) const; - const NRVec operator*(const NRMat &mat) const; + inline const T operator*(const NRVec &rhs) const; //scalar product -> dot + inline const T dot(const NRVec &rhs) const {return *this * rhs;}; //@@@for complex do conjugate + void gemv(const T beta, const NRMat &a, const char trans, const T alpha, const NRVec &x); + void gemv(const T beta, const NRSMat &a, const char trans /*just for compatibility*/, const T alpha, const NRVec &x); + void gemv(const T beta, const SparseMat &a, const char trans, const T alpha, const NRVec &x); + const NRVec operator*(const NRMat &mat) const {NRVec result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;}; + const NRVec operator*(const NRSMat &mat) const {NRVec result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;}; + const NRVec operator*(const SparseMat &mat) const {NRVec result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;}; const NRMat operator|(const NRVec &rhs) const; inline const T sum() const; //sum of its elements inline const T dot(const T *a, const int stride=1) const; // ddot with a stride-vector @@ -98,8 +84,6 @@ public: ~NRVec(); void axpy(const T alpha, const NRVec &x); // this+= a*x void axpy(const T alpha, const T *x, const int stride=1); // this+= a*x - void gemv(const T beta, const NRMat &a, const char trans, - const T alpha, const NRVec &x); void copyonwrite(); void resize(const int n); void get(int fd, bool dimensions=1); @@ -112,11 +96,14 @@ public: void fscanf(FILE *f, const char *format); //sparse matrix concerning members explicit NRVec(const SparseMat &rhs); // dense from sparse matrix with one of dimensions =1 - const NRVec operator*(const SparseMat &mat) const; //vector*matrix inline void simplify() {}; //just for compatibility with sparse ones - void gemv(const T beta, const SparseMat &a, const char trans, const T alpha, const NRVec &x); }; +//due to mutual includes this has to be after full class declaration +#include "mat.h" +#include "smat.h" +#include "sparsemat.h" + template ostream & operator<<(ostream &s, const NRVec &x); template istream & operator>>(istream &s, NRVec &x); @@ -313,28 +300,36 @@ inline NRVec & NRVec::operator*=(const T &a) inline const double NRVec::operator*(const NRVec &rhs) const { #ifdef DEBUG - if (nn != rhs.nn) laerror("ddot of incompatible vectors"); + if (nn != rhs.nn) laerror("dot of incompatible vectors"); #endif - return cblas_ddot(nn, v, 1, rhs.v, 1); + return cblas_ddot(nn, v, 1, rhs.v, 1); } + + inline const complex NRVec< complex >::operator*(const NRVec< complex > &rhs) const { #ifdef DEBUG - if (nn != rhs.nn) laerror("ddot of incompatible vectors"); + if (nn != rhs.nn) laerror("dot of incompatible vectors"); #endif complex dot; cblas_zdotc_sub(nn, (void *)v, 1, (void *)rhs.v, 1, (void *)(&dot)); return dot; } -// Vec * SMat = SMat * Vec -template -inline const NRVec NRVec::operator*(const NRSMat & S) const +template +inline const T NRVec::operator*(const NRVec &rhs) const { - return S * (*this); +#ifdef DEBUG + if (nn != rhs.nn) laerror("dot of incompatible vectors"); +#endif + T dot = 0; + for(int i=0; i::sum() const {