diff --git a/mat.cc b/mat.cc
index 895a04d..1007f0d 100644
--- a/mat.cc
+++ b/mat.cc
@@ -34,24 +34,6 @@ NRMat<T> & NRMat<T>::operator=(const T &a)
 
 
 
-//get diagonal; for compatibility with large matrices do not return newly created object
-template <typename T>
-void NRMat<T>::diagonalof(NRVec<T> &r) const
-{
-#ifdef DEBUG
-	if (nn != mm) laerror("diagonalof() non-square matrix");
-	if (r.size() != nn) laerror("diagonalof() incompatible vector");
-#endif
-
-#ifdef MATPTR
-	for (int i=0; i< nn; i++) r[i] = v[i][i];
-#else
-	{int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) r[j] = v[i];}
-#endif
-
-}
-
-
 
 // M += a
 template <typename T>
@@ -757,6 +739,41 @@ const complex<double> NRMat< complex<double> >::trace() const
 	return sum;
 }
 
+
+
+//get diagonal; for compatibility with large matrices do not return newly created object
+//for non-square get diagonal of A^TA, will be used as preconditioner
+//r must hold mm elements (mm==nn when square): r[i] receives v[i][i] for a square
+//matrix, otherwise the dot product of column i with itself
+void NRMat<double>::diagonalof(NRVec<double> &r) const
+{
+#ifdef DEBUG
+	//one result entry per column is written below, so compare with mm, not nn
+	if (r.size() != mm) laerror("diagonalof() incompatible vector");
+#endif
+
+if(nn==mm)
+{
+#ifdef MATPTR
+	for (int i=0; i< nn; i++) r[i] = v[i][i];
+#else
+	{int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) r[j] = v[i];}
+#endif
+}
+else //non-square
+{
+for (int i=0; i< mm; i++)
+#ifdef MATPTR
+	r[i] = cblas_ddot(nn,v[0]+i,mm,v[0]+i,mm); //nn elements starting at offset i, stride mm
+#else
+	r[i] = cblas_ddot(nn,v+i,mm,v+i,mm); //nn elements starting at offset i, stride mm
+#endif
+}
+
+}
+
+
+
 //////////////////////////////////////////////////////////////////////////////
 //// forced instantization in the corespoding object file
 #define INSTANTIZE(T) \
diff --git a/sparsemat.cc b/sparsemat.cc
index 80dd2f7..5c2c789 100644
--- a/sparsemat.cc
+++ b/sparsemat.cc
@@ -490,15 +490,23 @@ template <class T>
 void SparseMat<T>::diagonalof(NRVec<T> &r) const
 {
 #ifdef DEBUG
-if(nn!=mm) laerror("diagonalof() non-square sparse matrix");
+if((int)mm!=r.size()) laerror("incompatible vector size in diagonalof()");
 #endif
 matel<T> *l=list;
 r=(T)0;
-while(l)
-	{
-	if(l->row == l->col) r[l->row]+= l->elem;
-	l= l->next;
-	}
+if(nn==mm) //square
+	while(l)
+		{
+		if(l->row == l->col) r[l->row]+= l->elem;
+		l= l->next;
+		}
+else //diagonal of A^TA, assuming it has been simplified (only one entry per position), will be used as preconditioner only anyway
+	//NOTE(review): off-diagonal squares are doubled when 'symmetric' is set, presumably to account for one-triangle storage - confirm
+	while(l)
+		{
+		r[l->col] += l->elem*l->elem *(l->col!=l->row && symmetric?2.:1.);
+		l= l->next;
+		}
 }
 
 