diff --git a/mat.cc b/mat.cc index 0bc7076..233252c 100644 --- a/mat.cc +++ b/mat.cc @@ -1109,9 +1109,7 @@ const complex NRMat< complex >::trace() const template<> const double * NRMat::diagonalof(NRVec &r, const bool divide, bool cache) const { -#ifdef DEBUG if (r.size() != nn) laerror("diagonalof() incompatible vector"); -#endif double a; @@ -1145,6 +1143,23 @@ return divide?NULL:&r[0]; } +//set diagonal +template<> +void NRMat::diagonalset(const NRVec &r) +{ + if (r.size() != nn) laerror("diagonalset() incompatible vector"); + if(nn!=mm) laerror("diagonalset only for square matrix"); + +copyonwrite(); + +#ifdef MATPTR +for (int i=0; i< nn; i++) v[i][i] = r[i]; +#else +{int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) v[i] = r[j];} +#endif +} + + diff --git a/mat.h b/mat.h index 8c02a92..a1a1e9f 100644 --- a/mat.h +++ b/mat.h @@ -92,6 +92,7 @@ public: const NRVec row(const int i, int l= -1) const; //row of, efficient const NRVec column(const int j, int l= -1) const {if(l<0) l=nn; NRVec r(l); for(int i=0; i &, const bool divide=0, bool cache=false) const; //get diagonal + void diagonalset(const NRVec &); //set diagonal elements void gemv(const T beta, NRVec &r, const char trans, const T alpha, const NRVec &x) const {r.gemv(beta,*this,trans,alpha,x);}; inline T* operator[](const int i); //subscripting: pointer to row i inline const T* operator[](const int i) const;