From 8f037f7ccf67ee22b6b199ddfb2178227cde2cc6 Mon Sep 17 00:00:00 2001 From: Jiri Pittner Date: Tue, 4 Mar 2025 16:46:59 +0100 Subject: [PATCH] NRSMat sum() --- smat.h | 11 +++++++++++ vec.cc | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ vec.h | 6 +----- 3 files changed, 65 insertions(+), 5 deletions(-) diff --git a/smat.h b/smat.h index 6f563c7..085220a 100644 --- a/smat.h +++ b/smat.h @@ -169,6 +169,17 @@ public: inline const T amax() const; inline const T amin() const; + //! sum all matrix elements + const T sum(bool onlytriangle=false) const + { + T s = NRVec(*this).sum(); + if(onlytriangle) return s; + s*=2; + for(int i=0; i >& NRVec >::conjugateme() { } +/***************************************************************************//** + * sum up the elements of current vector of general type T + * @return sum + ******************************************************************************/ +template +const T NRVec::sum() const { + NOT_GPU(*this); + T sum; + + sum = (T)0; + for(int i=0; i +const double NRVec::sum() const { + double result=0; +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_daxpy(nn, 1.0, v, 1, &result, 0); +#ifdef CUDALA + }else{ + laerror("not implemented"); + } + } +#endif + return result; +} + +/***************************************************************************//** + * sum up the all of the current double-precision complex vector + * @return sum + ******************************************************************************/ +template <> +const std::complex NRVec >::sum() const { + std::complex result=0; +#ifdef CUDALA + if(location == cpu){ +#endif + cblas_zaxpy(nn, &CONE, v, 1, &result, 0); +#ifdef CUDALA + }else{ + laerror("not implemented"); + } + } +#endif + return result; +} diff --git a/vec.h b/vec.h index 4f248aa..6b55b6c 100644 --- a/vec.h +++ b/vec.h @@ -365,11 +365,7 @@ public: const NRVec otimes2vec(const NRVec &rhs, const bool conjugate = false, const T &scale = 1) const; //! compute the sum of the vector elements - inline const T sum() const { - T sum(v[0]); - for(register int i=1; i