From df9ac6894b09fdaa188a1447b76e2c5f1f3f73f8 Mon Sep 17 00:00:00 2001 From: jiri Date: Thu, 25 Feb 2010 20:47:01 +0000 Subject: [PATCH] *** empty log message *** --- mat.h | 3 + nonclass.cc | 292 ++++++++++++++++++++++++++++++++++++++++++---------- nonclass.h | 73 ++++++------- vec.h | 1 + 4 files changed, 274 insertions(+), 95 deletions(-) diff --git a/mat.h b/mat.h index 3559881..14126de 100644 --- a/mat.h +++ b/mat.h @@ -127,6 +127,9 @@ public: void axpy(const T alpha, const NRMat &x); // this += a*x inline const T amax() const; const T trace() const; + NRMat & SwapRows(); + NRMat & SwapCols(); + NRMat & SwapRowsCols(); //members concerning sparse matrix SparseSMat operator*(const SparseSMat &rhs) const; diff --git a/nonclass.cc b/nonclass.cc index 3ba7186..b70490d 100644 --- a/nonclass.cc +++ b/nonclass.cc @@ -238,71 +238,259 @@ linear_solve_do(a,&B[0],1,a.nrows(),det,n); //other version of linear solver based on gesvx - +//------------------------------------------------------------------------------ extern "C" void FORNAME(zgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, complex *A, const int *lda, complex *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, complex *B, const int *ldb, complex *X, const int *ldx, double *rcond, double *ferr, double *berr, complex *work, double *rwork, int *info); -extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, double *ferr, double *berr, double *work, double *rwork, int *info); +extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, 
double *ferr, double *berr, double *work, int *iwork, int *info); +//------------------------------------------------------------------------------ +// solves set of linear equations using dgesvx +// input: +// _A double precision matrix of dimension nn x mm, where min(nn, mm) >= n +// _B double prec. array dimensioned as nrhs x n +// _rhsCount nrhs - count of right hand sides +// _eqCount n - count of equations +// _eq use equilibration of matrix A before solving +// _saveA if set, do not overwrite A if equilibration in effect +// _rcond if not NULL, store the returned rcond value from dgesvx +// output: +// solution is stored in _B +// the info parameter of dgesvx is returned (see man dgesvx) +//------------------------------------------------------------------------------ +int linear_solve_x(NRMat &_A, double *_B, const int _rhsCount, const int _eqCount, const bool _eq, const bool _saveA, double *_rcond){ + const int A_rows = _A.nrows(); + const int A_cols = _A.ncols(); -int linear_solve_x_(NRMat > &A, complex *B, const bool eq, const int nrhs, const int ldb, const char trans) -{ - const int n_rows = A.nrows(); - const int n_cols = A.ncols(); + const char fact = _eq?'E':'N'; + const char trans = 'T';//because of c-order + char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const - if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x"); - const int n = n_rows; - const char fact = eq?'E':'N'; - char equed = 'B';//fact = 'N' => equed is an output argument + if(_eqCount < 0 || _eqCount > A_rows || _eqCount > A_cols || _rhsCount < 0){ + laerror("linear_solve_x: invalid input matrices"); + } - int info, lwork; - double rcond, ferr[nrhs], berr[nrhs], rwork[2*n]; - double R[n], C[n]; - complex *AF = new complex[n*n]; - complex *work = new complex[2*n]; - NRMat > X(n, nrhs); - int ipiv[n]; + double *A; + double * const _A_data = (double*)_A; - A.copyonwrite(); + int info; + const int nrhs = _rhsCount; + const int n = 
_eqCount; + int lda = A_cols; + const int ldaf = lda; - FORNAME(zgesvx)(&fact, &trans, &n_rows, &nrhs, \ - A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \ - &B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info); + double rcond; + double ferr[nrhs], berr[nrhs], work[4*n]; + double R[n], C[n]; - delete[] work; - delete[] AF; - memcpy(B, X[0], sizeof(complex)*n*nrhs); - return info; + int *const iwork = new int[n]; + int *const ipiv = new int[n]; + + double *X = new double[n*nrhs]; + double *AF = new double[ldaf*n]; + + A = _A_data; + if(_eq){ + if(_saveA){//store the corresponding submatrix of _A (not needed provided fact=='N') + A = new double[n*n]; + int offset1 = 0;int offset2 = 0; + for(register int i=0;i= n +// _B double prec. complex array dimensioned as nrhs x n +// _rhsCount nrhs - count of right hand sides +// _eqCount n - count of equations +// _eq use equilibration +// _saveA if set, do not overwrite A if equilibration in effect +// _rcond if not NULL, store the returned rcond value from dgesvx +// output: +// solution is stored in _B +// the info parameter of dgesvx is returned (see man dgesvx) +//------------------------------------------------------------------------------ +int linear_solve_x(NRMat > &_A, complex *_B, const int _rhsCount, const int _eqCount, const bool _eq, const bool _saveA, double *_rcond){ + const int A_rows = _A.nrows(); + const int A_cols = _A.ncols(); + + const char fact = _eq?'E':'N'; + const char trans = 'T';//because of c-order + char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const + + if(_eqCount < 0 || _eqCount > A_rows || _eqCount > A_cols || _rhsCount < 0){ + laerror("linear_solve_x: invalid input matrices"); + } + + complex *A; + complex * const _A_data = (complex*)_A; + + int info; + const int nrhs = _rhsCount; + const int n = _eqCount; + int lda = A_cols; + const int ldaf = lda; + + double rcond; + double ferr[nrhs], berr[nrhs]; + 
double R[n], C[n], rwork[2*n]; + complex work[2*n]; + + int *const ipiv = new int[n]; + + complex *X = new complex[n*nrhs]; + complex *AF = new complex[ldaf*n]; + + A = _A_data; + if(_eq){ + if(_saveA){//store the corresponding submatrix of _A (not needed provided fact=='N') + A = new complex[n*n]; + int offset1 = 0;int offset2 = 0; + for(register int i=0;i &A, double *B, const bool eq, const int nrhs, const int ldb, const char trans) -{ - const int n_rows = A.nrows(); - const int n_cols = A.ncols(); + if(_rcond)*_rcond = rcond; + cblas_zcopy(n*nrhs, X, 1, _B, 1);//store the solution - if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x"); - const int n = n_rows; - const char fact = eq?'E':'N'; - char equed = 'B';//fact = 'N' => equed is an output argument - - int info, lwork; - double rcond, ferr[nrhs], berr[nrhs], rwork[2*n]; - double R[n], C[n]; - double *AF = new double[n*n]; - double *work = new double[2*n]; - NRMat X(n, nrhs); - int ipiv[n]; - - A.copyonwrite(); - - FORNAME(dgesvx)(&fact, &trans, &n_rows, &nrhs, \ - A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \ - &B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info); - - delete[] work; - delete[] AF; - memcpy(B, X[0], sizeof(double)*n*nrhs); - return info; + delete[] ipiv; + delete[] AF;delete[] X; + if(_saveA){//NOTE(review): as far as visible here, A is pointed at a new[] buffer only inside if(_eq){ if(_saveA) above; with _eq false A still equals _A_data at this point - confirm this delete[] is intended + delete[] A; + } + return info; } +//------------------------------------------------------------------------------ +// for given square matrices A, B computes X = AB^{-1} as follows +// XB = A => B^TX^T = A^T +// input: +// _A double precision matrix of dimension nn x nn +// _B double prec. 
matrix of dimension nn x nn +// _useEq use equilibration suitable for badly conditioned matrices +// _rcond if not NULL, store the returned value of rcond from dgesvx +// output: +// solution is stored in _B +// the info parameter of dgesvx is returned (see man dgesvx) +//------------------------------------------------------------------------------ +int multiply_by_inverse(NRMat &_A, NRMat &_B, bool _useEq, double *_rcond){ + + const int n = _A.nrows(); + const int m = _A.ncols(); + if(n != m || n != _B.nrows() || n != _B.ncols()){ + laerror("multiply_by_inverse: incompatible matrices"); + } + + const char fact = _useEq?'E':'N'; + const char trans = 'N';//because of c-order + char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const + const int n2 = n*n; + + double * const A = (double*)_A; + double * const B = (double*)_B; + _B.copyonwrite();//even if fact='N', call copyonwrite because the solution is going to be stored in _B (NOTE(review): the raw pointer B above was read from _B before this call - confirm copyonwrite() never relocates the data) + int info; + double rcond; + double ferr[n], berr[n], work[4*n]; + double R[n], C[n]; + + int *const iwork = new int[n]; + int *const ipiv = new int[n]; + + double *X = new double[n2]; + double *AF = new double[n2]; + + FORNAME(dgesvx)(&fact, &trans, &n, &n, B, &n, AF, &n, &ipiv[0], &equed, &R[0], &C[0], A, &n, X, &n, &rcond, ferr, berr, work, iwork, &info);//NOTE(review): B (data of _B) is passed as the system matrix and A (data of _A) as the right-hand sides; with fact=='E' LAPACK may rescale the right-hand sides in place and _A gets no copyonwrite() here - confirm + + + if(_rcond)*_rcond = rcond; + cblas_dcopy(n2, X, 1, B, 1);//store the solution + + delete[] iwork;delete[] ipiv; + delete[] AF;delete[] X; + + return info; +} +//------------------------------------------------------------------------------ +// for given square matrices A, B computes X = AB^{-1} as follows +// XB = A => B^TX^T = A^T +// input: +// _A double precision matrix of dimension nn x nn +// _B double prec. 
matrix of dimension nn x nn +// _useEq use equilibration suitable for badly conditioned matrices +// _rcond if not NULL, store the returned value of rcond from zgesvx +// output: +// solution is stored in _B +// the info parameter of zgesvx is returned (see man zgesvx) +//------------------------------------------------------------------------------ +int multiply_by_inverse(NRMat > &_A, NRMat > &_B, bool _useEq, double *_rcond){ + + const int n = _A.nrows(); + const int m = _A.ncols(); + if(n != m || n != _B.nrows() || n != _B.ncols()){ + laerror("multiply_by_inverse: incompatible matrices"); + } + const int n2 = n*n; + + const char fact = _useEq?'E':'N'; + const char trans = 'N';//because of c-order + char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const + + complex * const A = (complex*)_A; + complex * const B = (complex*)_B; + _B.copyonwrite();//even if fact='N', call copyonwrite because the solution is going to be stored in _B (NOTE(review): the raw pointer B above was read from _B before this call - confirm copyonwrite() never relocates the data) + + int info; + double rcond; + double ferr[n], berr[n]; + double R[n], C[n], rwork[2*n]; + complex work[2*n]; + + int *const ipiv = new int[n]; + + complex *X = new complex[n2]; + complex *AF = new complex[n2]; + + FORNAME(zgesvx)(&fact, &trans, &n, &n, B, &n, AF, &n, &ipiv[0], &equed, &R[0], &C[0], A, &n, X, &n, &rcond, ferr, berr, work, rwork, &info);//NOTE(review): B (data of _B) is passed as the system matrix and A (data of _A) as the right-hand sides; with fact=='E' LAPACK may rescale the right-hand sides in place and _A gets no copyonwrite() here - confirm + + + if(_rcond)*_rcond = rcond; + cblas_zcopy(n2, X, 1, B, 1);//store the solution + + delete[] ipiv; + delete[] AF;delete[] X; + + return info; +} +//------------------------------------------------------------------------------ diff --git a/nonclass.h b/nonclass.h index a9795c0..e0cdb37 100644 --- a/nonclass.h +++ b/nonclass.h @@ -88,8 +88,8 @@ extern const NRVec diagofproduct(const NRMat &a, const NRMat &b,\ extern T trace2(const NRMat &a, const NRMat &b, bool trb=0); \ extern T trace2(const NRSMat &a, const NRSMat &b, const bool diagscaled=0);\ extern T trace2(const NRSMat &a, const NRMat &b, const bool diagscaled=0);\ -extern void 
linear_solve(NRMat &a, NRMat *b, double *det=0,int n=0); \ -extern void linear_solve(NRSMat &a, NRMat *b, double *det=0, int n=0); \ +extern void linear_solve(NRMat &a, NRMat *b, double *det=0,int n=0); /*solve Ax^T=b^T (b is nrhs x n) */ \ +extern void linear_solve(NRSMat &a, NRMat *b, double *det=0, int n=0); /*solve Ax^T=b^T (b is nrhs x n) */\ extern void linear_solve(NRMat &a, NRVec &b, double *det=0, int n=0); \ extern void linear_solve(NRSMat &a, NRVec &b, double *det=0, int n=0); \ extern void diagonalize(NRMat &a, NRVec::normtype> &w, const bool eivec=1, const bool corder=1, int n=0, NRMat *b=NULL, const int itype=1); \ @@ -184,51 +184,38 @@ return det; } -//extended linear solve routines -template -extern int linear_solve_x_(NRMat &A, T *B, const bool eq, const int nrhs, const int ldb, const char trans); - - -//solve Ax = b using zgesvx +//------------------------------------------------------------------------------ +// solves set of linear equations using gesvx +// input: +// A double precision matrix of dimension nn x mm, where min(nn, mm) >= n +// B double prec. 
array dimensioned as nrhs x n +// rhsCount nrhs - count of right hand sides +// eqCount n - count of equations +// eq use equilibration of matrix A before solving +// saveA if set, do not overwrite A if equilibration in effect +// rcond if not NULL, store the returned rcond value from dgesvx +// output: +// solution is stored in B +// the info parameter of gesvx is returned (see man dgesvx) +//------------------------------------------------------------------------------ template -inline int linear_solve_x(NRMat > &A, NRVec > &B, const bool eq) -{ -B.copyonwrite(); -return linear_solve_x_(A, &B[0], eq, 1, B.size(), 'T'); -} +int linear_solve_x(NRMat &A, T *B, const int rhsCount, const int eqCount, const bool eq, const bool saveA, double *rcond); -//solve AX = B using zgesvx +//------------------------------------------------------------------------------ +// for given square matrices A, B computes X = AB^{-1} as follows +// XB = A => B^TX^T = A^T +// input: +// _A double precision matrix of dimension nn x nn +// _B double prec. 
matrix of dimension nn x nn +// _useEq use equilibration suitable for badly conditioned matrices +// _rcond if not NULL, store the returned value of rcond from dgesvx +// output: +// solution is stored in _B +// the info parameter of dgesvx is returned (see man dgesvx) +//------------------------------------------------------------------------------ template -inline int linear_solve_x(NRMat > &A, NRMat > &B, const bool eq, const bool transpose=true) -{ -B.copyonwrite(); -if(transpose) B.transposeme();//because of corder -int info(0); -info = linear_solve_x_(A, B[0], eq, B.ncols(), B.nrows(), transpose?'T':'N'); -if(transpose) B.transposeme(); -return info; -} - - -#define multiply_by_inverse(P,Q,eq) linear_solve_x(P,Q,eq,false) -/* - * input: - * P,Q - general complex square matrices - * eq - use equilibration (man cgesvx) - * description: - * evaluates matrix expression QP^{-1} as - * Z = QP^{-1} - * ZP = Q - * P^TZ^T = Q^T - * Z is computed by solving this linear system instead of computing inverse - * of P followed by multiplication by Q - * returns: - * returns the info parameter of cgesvx - * result is stored in Q - */ - - +int multiply_by_inverse(NRMat &A, NRMat &B, bool useEq, double *rcond); //general submatrix, INDEX will typically be NRVec or even int* diff --git a/vec.h b/vec.h index 6a82060..58e5cd0 100644 --- a/vec.h +++ b/vec.h @@ -139,6 +139,7 @@ public: bool smaller(int i, int j) const {return LA_traits::smaller(v[i],v[j]);}; void swap(int i, int j) {T tmp; tmp=v[i]; v[i]=v[j]; v[j]=tmp;}; int sort(int direction=0, int from=0, int to= -1, int *perm=NULL); //sort, ascending by default, returns parity of permutation + NRVec & CallOnMe(T (*_F)(const T &) ) {copyonwrite(); for(int i=0; i