diff --git a/la_traits.h b/la_traits.h index 6191724..0c221c5 100644 --- a/la_traits.h +++ b/la_traits.h @@ -45,6 +45,7 @@ class isscalar { //specializations #define SCALAR(X) \ +template<>\ class isscalar {typedef scalar_true scalar_type;}; //declare what is scalar diff --git a/mat.h b/mat.h index eae57dc..548d053 100644 --- a/mat.h +++ b/mat.h @@ -24,8 +24,10 @@ public: NRMat(const T *a, const int n, const int m); inline NRMat(const NRMat &rhs); explicit NRMat(const NRSMat &rhs); -#ifndef MATPTR - NRMat(const NRVec &rhs, const int n, const int m); +#ifdef MATPTR + explicit NRMat(const NRVec &rhs, const int n, const int m) :NRMat(&rhs[0][0],n,m) {}; +#else + explicit NRMat(const NRVec &rhs, const int n, const int m); #endif ~NRMat(); #ifdef MATPTR @@ -329,6 +331,7 @@ inline NRMat::operator const T* () const } // max element of Mat +template<> inline const double NRMat::amax() const { #ifdef MATPTR @@ -337,6 +340,7 @@ inline const double NRMat::amax() const return v[cblas_idamax(nn*mm, v, 1)]; #endif } +template<> inline const complex NRMat< complex >::amax() const { #ifdef MATPTR diff --git a/smat.h b/smat.h index ffd0708..5c139cc 100644 --- a/smat.h +++ b/smat.h @@ -14,7 +14,7 @@ public: friend class NRVec; friend class NRMat; - inline NRSMat::NRSMat() : nn(0),v(0),count(0) {}; + inline NRSMat() : nn(0),v(0),count(0) {}; inline explicit NRSMat(const int n); // Zero-based array inline NRSMat(const T &a, const int n); //Initialize to constant inline NRSMat(const T *a, const int n); // Initialize to array @@ -121,12 +121,15 @@ NRSMat::NRSMat(const NRVec &rhs, const int n) // type conversion } // S *= a +template<> inline NRSMat & NRSMat::operator*=(const double & a) { copyonwrite(); cblas_dscal(NN2, a, v, 1); return *this; } + +template<> inline NRSMat< complex > & NRSMat< complex >::operator*=(const complex & a) { @@ -163,6 +166,7 @@ inline NRSMat & NRSMat::operator-=(const T &a) } // S += S +template<> inline NRSMat & NRSMat::operator+=(const NRSMat & rhs) { @@ -173,6 +177,8 @@ NRSMat::operator+=(const NRSMat & rhs) cblas_daxpy(NN2, 1.0, rhs.v, 1, v, 1); return *this; } + +template<> NRSMat< complex > & NRSMat< complex >::operator+=(const NRSMat< complex > & rhs) { @@ -197,6 +203,7 @@ inline NRSMat & NRSMat::operator+=(const NRSMat & rhs) // S -= S +template<> inline NRSMat & NRSMat::operator-=(const NRSMat & rhs) { @@ -207,6 +214,8 @@ NRSMat::operator-=(const NRSMat & rhs) cblas_daxpy(NN2, -1.0, rhs.v, 1, v, 1); return *this; } + +template<> inline NRSMat< complex > & NRSMat< complex >::operator-=(const NRSMat< complex > & rhs) { @@ -306,10 +315,12 @@ inline int NRSMat::size() const // max value +template<> inline const double NRSMat::amax() const { return v[cblas_idamax(NN2, v, 1)]; } +template<> inline const complex NRSMat< complex >::amax() const { return v[cblas_izamax(NN2, (void *)v, 1)]; diff --git a/vec.cc b/vec.cc index 377938c..c115c04 100644 --- a/vec.cc +++ b/vec.cc @@ -4,6 +4,7 @@ #include #include #include +#include extern "C" { extern ssize_t read(int, void *, size_t); extern ssize_t write(int, const void *, size_t); @@ -171,6 +172,7 @@ return nn void NRVec::axpy(const double alpha, const NRVec &x) { #ifdef DEBUG @@ -181,6 +183,7 @@ void NRVec::axpy(const double alpha, const NRVec &x) } // axpy call for T = complex (not strided) +template<> void NRVec< complex >::axpy(const complex alpha, const NRVec< complex > &x) { @@ -192,6 +195,7 @@ void NRVec< complex >::axpy(const complex alpha, } // axpy call for T = double (strided) +template<> void NRVec::axpy(const double alpha, const double *x, const int stride) { copyonwrite(); @@ -199,6 +203,7 @@ void NRVec::axpy(const double alpha, const double *x, const int stride) } // axpy call for T = complex (strided) +template<> void NRVec< complex >::axpy(const complex alpha, const complex *x, const int stride) { @@ -207,6 +212,7 @@ void NRVec< complex >::axpy(const complex alpha, } // unary minus +template<> const NRVec NRVec::operator-() const { NRVec result(*this); @@ -214,6 +220,8 @@ const NRVec NRVec::operator-() const cblas_dscal(nn, -1.0, result.v, 1); return result; } + +template<> const NRVec< complex > NRVec< complex >::operator-() const { @@ -236,6 +244,7 @@ NRVec & NRVec::operator=(const T &a) } // Normalization of NRVec +template<> NRVec & NRVec::normalize() { double tmp; @@ -251,6 +260,7 @@ NRVec & NRVec::normalize() } // Normalization of NRVec< complex > +template<> NRVec< complex > & NRVec< complex >::normalize() { complex tmp; @@ -265,9 +275,13 @@ NRVec< complex > & NRVec< complex >::normalize() } //stubs for linkage +template<> NRVec & NRVec::normalize() {laerror("normalize() impossible for integer types"); return *this;} +template<> NRVec & NRVec::normalize() {laerror("normalize() impossible for integer types"); return *this;} +template<> NRVec & NRVec::normalize() {laerror("normalize() impossible for integer types"); return *this;} +template<> void NRVec::gemv(const int beta, const NRSMat &A, const char trans, const int alpha, const NRVec &x) @@ -275,6 +289,7 @@ void NRVec::gemv(const int beta, laerror("not yet implemented"); } +template<> void NRVec::gemv(const short beta, const NRSMat &A, const char trans, const short alpha, const NRVec &x) @@ -283,6 +298,7 @@ laerror("not yet implemented"); } +template<> void NRVec::gemv(const char beta, const NRSMat &A, const char trans, const char alpha, const NRVec &x) @@ -290,6 +306,7 @@ void NRVec::gemv(const char beta, laerror("not yet implemented"); } +template<> void NRVec::gemv(const int beta, const NRMat &A, const char trans, const int alpha, const NRVec &x) @@ -297,6 +314,7 @@ void NRVec::gemv(const int beta, laerror("not yet implemented"); } +template<> void NRVec::gemv(const short beta, const NRMat &A, const char trans, const short alpha, const NRVec &x) @@ -305,6 +323,7 @@ laerror("not yet implemented"); } +template<> void NRVec::gemv(const char beta, const NRMat &A, const char trans, const char alpha, const NRVec &x) @@ -312,6 +331,7 @@ void NRVec::gemv(const char beta, laerror("not yet implemented"); } +template<> void NRVec::gemv(const int beta, const SparseMat &A, const char trans, const int alpha, const NRVec &x) @@ -319,6 +339,7 @@ void NRVec::gemv(const int beta, laerror("not yet implemented"); } +template<> void NRVec::gemv(const short beta, const SparseMat &A, const char trans, const short alpha, const NRVec &x) @@ -327,6 +348,7 @@ laerror("not yet implemented"); } +template<> void NRVec::gemv(const char beta, const SparseMat &A, const char trans, const char alpha, const NRVec &x) @@ -338,6 +360,7 @@ laerror("not yet implemented"); // gemv calls +template<> void NRVec::gemv(const double beta, const NRMat &A, const char trans, const double alpha, const NRVec &x) { @@ -349,6 +372,7 @@ void NRVec::gemv(const double beta, const NRMat &A, A.nrows(), A.ncols(), alpha, A, A.ncols(), x.v, 1, beta, v, 1); } +template<> void NRVec< complex >::gemv(const complex beta, const NRMat< complex > &A, const char trans, const complex alpha, const NRVec &x) @@ -363,6 +387,7 @@ void NRVec< complex >::gemv(const complex beta, } +template<> void NRVec::gemv(const double beta, const NRSMat &A, const char trans, const double alpha, const NRVec &x) { @@ -374,6 +399,7 @@ void NRVec::gemv(const double beta, const NRSMat &A, } +template<> void NRVec< complex >::gemv(const complex beta, const NRSMat< complex > &A, const char trans, const complex alpha, const NRVec &x) @@ -391,12 +417,14 @@ void NRVec< complex >::gemv(const complex beta, // Direc product Mat = Vec | Vec +template<> const NRMat NRVec::operator|(const NRVec &b) const { NRMat result(0.,nn,b.nn); cblas_dger(CblasRowMajor, nn, b.nn, 1., v, 1, b.v, 1, result, b.nn); return result; } +template<> const NRMat< complex > NRVec< complex >::operator|(const NRVec< complex > &b) const { diff --git a/vec.h b/vec.h index ef306aa..8e59b97 100644 --- a/vec.h +++ b/vec.h @@ -45,7 +45,9 @@ public: inline NRVec(const T *a, const int n); inline NRVec(const NRVec &rhs); inline explicit NRVec(const NRSMat & S); -#ifndef MATPTR +#ifdef MATPTR + explicit NRVec(const NRMat &rhs) : NRVec(&rhs[0][0],rhs.nrows()*rhs.ncols()) {}; +#else explicit NRVec(const NRMat &rhs); #endif NRVec & operator=(const NRVec &rhs); @@ -152,6 +154,7 @@ inline NRVec::NRVec(const NRSMat &rhs) } // x += a +template<> inline NRVec & NRVec::operator+=(const double &a) { copyonwrite(); @@ -159,6 +162,7 @@ inline NRVec & NRVec::operator+=(const double &a) return *this; } +template<> inline NRVec< complex > & NRVec< complex >::operator+=(const complex &a) { @@ -179,12 +183,15 @@ inline NRVec & NRVec::operator+=(const T &a) // x -= a +template<> inline NRVec & NRVec::operator-=(const double &a) { copyonwrite(); cblas_daxpy(nn, 1.0, &a, 0, v, 1); return *this; } + +template<> inline NRVec< complex > & NRVec< complex >::operator-=(const complex &a) { @@ -205,6 +212,7 @@ inline NRVec & NRVec::operator-=(const T &a) // x += x +template<> inline NRVec & NRVec::operator+=(const NRVec &rhs) { #ifdef DEBUG @@ -214,6 +222,8 @@ inline NRVec & NRVec::operator+=(const NRVec &rhs) cblas_daxpy(nn, 1.0, rhs.v, 1, v, 1); return *this; } + +template<> inline NRVec< complex > & NRVec< complex >::operator+=(const NRVec< complex > &rhs) { @@ -240,6 +250,7 @@ inline NRVec & NRVec::operator+=(const NRVec &rhs) // x -= x +template<> inline NRVec & NRVec::operator-=(const NRVec &rhs) { #ifdef DEBUG @@ -249,6 +260,8 @@ inline NRVec & NRVec::operator-=(const NRVec &rhs) cblas_daxpy(nn, -1.0, rhs.v, 1, v, 1); return *this; } + +template<> inline NRVec< complex > & NRVec< complex >::operator-=(const NRVec< complex > &rhs) { @@ -275,12 +288,15 @@ inline NRVec & NRVec::operator-=(const NRVec &rhs) // x *= a +template<> inline NRVec & NRVec::operator*=(const double &a) { copyonwrite(); cblas_dscal(nn, a, v, 1); return *this; } + +template<> inline NRVec< complex > & NRVec< complex >::operator*=(const complex &a) { @@ -301,6 +317,7 @@ inline NRVec & NRVec::operator*=(const T &a) // scalar product x.y +template<> inline const double NRVec::operator*(const NRVec &rhs) const { #ifdef DEBUG @@ -310,6 +327,7 @@ inline const double NRVec::operator*(const NRVec &rhs) const } +template<> inline const complex NRVec< complex >::operator*(const NRVec< complex > &rhs) const { @@ -335,10 +353,12 @@ inline const T NRVec::operator*(const NRVec &rhs) const // Sum of elements +template<> inline const double NRVec::sum() const { return cblas_dasum(nn, v, 1); } +template<> inline const complex NRVec< complex >::sum() const { @@ -348,10 +368,12 @@ NRVec< complex >::sum() const } // Dot product: x * y +template<> inline const double NRVec::dot(const double *y, const int stride) const { return cblas_ddot(nn, y, stride, v, 1); } +template<> inline const complex NRVec< complex >::dot(const complex *y, const int stride) const { @@ -407,20 +429,24 @@ inline NRVec::operator const T*() const } // return norm of the Vec +template<> inline const double NRVec::norm() const { return cblas_dnrm2(nn, v, 1); } +template<> inline const double NRVec< complex >::norm() const { return cblas_dznrm2(nn, (void *)v, 1); } // Max element of the array +template<> inline const double NRVec::amax() const { return v[cblas_idamax(nn, v, 1)]; } +template<> inline const complex NRVec< complex >::amax() const { return v[cblas_izamax(nn, (void *)v, 1)];