diff --git a/efector.cc b/efector.cc new file mode 100644 index 0000000..62916a7 --- /dev/null +++ b/efector.cc @@ -0,0 +1,100 @@ +/* + LA: linear algebra C++ interface library + Copyright (C) 2008 Jiri Pittner or + This file contributed by + + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +#include "la_traits.h" +#include +#include + +namespace LA { + +//-------------------------------------------------------------------------------- +class PrintNRCMat{ +private: +public: + std::ostringstream s; + PrintNRCMat(const PrintNRCMat &_fw); + PrintNRCMat(const NRCMat& _K, const int _prec = 5, const int _pl = 0, const int _pt = 0, const double _tol = 1.0e-12); + friend ostream& operator<<(ostream& os, const PrintNRCMat& fw){ + return (os << fw.s.str()); + } +}; +//-------------------------------------------------------------------------------- +PrintNRCMat::PrintNRCMat(const PrintNRCMat &_fw){ + this->s.str(_fw.s.str()); +} +//-------------------------------------------------------------------------------- +PrintNRCMat::PrintNRCMat(const NRCMat &_K, const int _prec, const int _pl, const int _pt, const double _tol){ +/* + formatuje realnou cast komplexni matice _K ve forme tabulky, napr. + ---------------------------------------------- + | +6.80375e+00 | +5.66198e+00 | +8.23295e+00 | + ---------------------------------------------- + | -3.29554e+00 | -4.44451e+00 | -4.52059e-01 | + ---------------------------------------------- + | -2.70431e+00 | +9.04459e+00 | +2.71423e+00 | + ---------------------------------------------- + + parametry: + _prec presnost pro cout << setprecision(...) + _pl left padding - odsazeni tabulky zleva + _pt top padding - odsazeni tabulky seshora + _tol cisla mensi v abs. hodnote nez _tol se zobrazi jako 0 +*/ + const int n = _K.nrows(); + const int m = _K.ncols(); + double val(0.0); + + + const int w = _prec + 7; + const int sirka = m*(w+3) + 1; + const int pl = (_pl>0&&_pl<10)?_pl:0; + const int pt = (_pt>0&&_pt<10)?_pt:0; + + const string vl = string(pl, ' '); + const string radek = string(sirka, '-'); + const string vypln = string(w, ' '); + + this->s << scientific << setprecision(_prec) << showpos; + this->s << string(pt, '\n'); + const string name = _K.getname(); + if(name != ""){ + this->s << vl << "matrix '" << name << "': " << endl; + } + this->s << vl << radek << endl; + for(register int i = 0;i < n;i++){ + this->s << vl; + for(register int j = 0;j < m; j++){ + val = (_K[i][j]).real(); + + this->s << setw(2) << right << "| " << setw(w) << left; + if(std::abs(val)<_tol){ + this->s << val; + }else{ + this->s << val; + } + this->s << setw(1) << " "; + } + this->s << "|" << endl; + this->s << vl << radek << endl; + } +} + + +}//namespace diff --git a/mat.cc b/mat.cc index 9375348..1ca0211 100644 --- a/mat.cc +++ b/mat.cc @@ -20,6 +20,7 @@ #include "mat.h" #include +#include #include #include #include @@ -401,13 +402,13 @@ template void NRMat::fscanf(FILE *f, const char *format) { int n, m; - if (std::fscanf(f, "%d %d", &n, &m) != 2) + if (::fscanf(f, "%d %d", &n, &m) != 2) laerror("cannot read matrix dimensions in Mat::fscanf()"); resize(n,m); T *p = *this; for(int i=0; i & NRMat::operator*=(const T &a) { copyonwrite(); #ifdef MATPTR - for (int i=0; i< nn*nn; i++) v[0][i] *= a; + for (int i=0; i< nn*mm; i++) v[0][i] *= a; #else - for (int i=0; i< nn*nn; i++) v[i] *= a; + for (int i=0; i< nn*mm; i++) v[i] *= a; #endif return *this; } @@ -608,9 +609,9 @@ NRMat & NRMat::operator+=(const NRMat &rhs) #endif copyonwrite(); #ifdef MATPTR - for (int i=0; i< nn*nn; i++) v[0][i] += rhs.v[0][i] ; + for (int i=0; i< nn*mm; i++) v[0][i] += rhs.v[0][i] ; #else - for (int i=0; i< nn*nn; i++) v[i] += rhs.v[i] ; + for (int i=0; i< nn*mm; i++) v[i] += rhs.v[i] ; #endif return *this; } @@ -656,9 +657,9 @@ NRMat & NRMat::operator-=(const NRMat &rhs) #endif copyonwrite(); #ifdef MATPTR - for (int i=0; i< nn*nn; i++) v[0][i] -= rhs.v[0][i] ; + for (int i=0; i< nn*mm; i++) v[0][i] -= rhs.v[0][i] ; #else - for (int i=0; i< nn*nn; i++) v[i] -= rhs.v[i] ; + for (int i=0; i< nn*mm; i++) v[i] -= rhs.v[i] ; #endif return *this; } diff --git a/mat.h b/mat.h index 36eabf1..3559881 100644 --- a/mat.h +++ b/mat.h @@ -66,6 +66,7 @@ public: NRMat & operator*=(const T &a); //multiply by a scalar NRMat & operator+=(const NRMat &rhs); NRMat & operator-=(const NRMat &rhs); + NRMat & operator^=(const NRMat &rhs); //Hadamard (element-wise) product NRMat & operator+=(const NRSMat &rhs); NRMat & operator-=(const NRSMat &rhs); const NRMat operator-() const; //unary minus @@ -649,6 +650,23 @@ public: +//Hadamard product +template +NRMat & NRMat::operator^=(const NRMat &rhs){ +#ifdef DEBUG + if (nn != rhs.nn || mm!= rhs.mm) + laerror("Mat ^= Mat of incompatible matrices"); +#endif + copyonwrite(); +#ifdef MATPTR + for (register int i=0; i< nn*mm; i++) v[0][i] *= rhs.v[0][i]; +#else + const int Dim = nn*mm; + for(register int i=0;i *A, const int *lda, complex *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, complex *B, const int *ldb, complex *X, const int *ldx, double *rcond, double *ferr, double *berr, complex *work, double *rwork, int *info); +extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, double *ferr, double *berr, double *work, double *rwork, int *info); + +int linear_solve_x_(NRMat > &A, complex *B, const bool eq, const int nrhs, const int ldb, const char trans) +{ + const int n_rows = A.nrows(); + const int n_cols = A.ncols(); + + if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x"); + const int n = n_rows; + const char fact = eq?'E':'N'; + char equed = 'B';//fact = 'N' => equed is an output argument + + int info, lwork; + double rcond, ferr[nrhs], berr[nrhs], rwork[2*n]; + double R[n], C[n]; + complex *AF = new complex[n*n]; + complex *work = new complex[2*n]; + NRMat > X(n, nrhs); + int ipiv[n]; + + A.copyonwrite(); + + FORNAME(zgesvx)(&fact, &trans, &n_rows, &nrhs, \ + A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \ + &B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info); + + delete[] work; + delete[] AF; + memcpy(B, X[0], sizeof(complex)*n*nrhs); + return info; +} + + +int linear_solve_x_(NRMat &A, double *B, const bool eq, const int nrhs, const int ldb, const char trans) +{ + const int n_rows = A.nrows(); + const int n_cols = A.ncols(); + + if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x"); + const int n = n_rows; + const char fact = eq?'E':'N'; + char equed = 'B';//fact = 'N' => equed is an output argument + + int info, lwork; + double rcond, ferr[nrhs], berr[nrhs], rwork[2*n]; + double R[n], C[n]; + double *AF = new double[n*n]; + double *work = new double[2*n]; + NRMat X(n, nrhs); + int ipiv[n]; + + A.copyonwrite(); + + FORNAME(dgesvx)(&fact, &trans, &n_rows, &nrhs, \ + A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \ + &B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info); + + delete[] work; + delete[] AF; + memcpy(B, X[0], sizeof(double)*n*nrhs); + return info; +} + + + + extern "C" void FORNAME(dsyev)(const char *JOBZ, const char *UPLO, const int *N, double *A, const int *LDA, double *W, double *WORK, const int *LWORK, int *INFO); @@ -991,7 +1061,81 @@ else } +//various norms +extern "C" double FORNAME(zlange)( const char *NORM, const int *M, const int *N, complex *A, const int *LDA, double *WORK); +extern "C" double FORNAME(dlange)( const char *NORM, const int *M, const int *N, double *A, const int *LDA, double *WORK); +double MatrixNorm(NRMat > &A, const char norm) +{ + const char TypNorm = (tolower(norm) == 'o')?'I':'O'; //switch c-order/fortran-order + const int M = A.nrows(); + const int N = A.ncols(); + double work[M]; + const double ret = FORNAME(zlange)(&TypNorm, &M, &N, A[0], &M, &work[0]); + return ret; +} + +double MatrixNorm(NRMat &A, const char norm) +{ + const char TypNorm = (tolower(norm) == 'o')?'I':'O'; //switch c-order/fortran-order + const int M = A.nrows(); + const int N = A.ncols(); + double work[M]; + const double ret = FORNAME(dlange)(&TypNorm, &M, &N, A[0], &M, &work[0]); + return ret; +} + + + +//condition number +extern "C" void FORNAME(zgecon)( const char *norm, const int *n, complex *A, const int *LDA, const double *anorm, double *rcond, complex *work, double *rwork, int *info); +extern "C" void FORNAME(dgecon)( const char *norm, const int *n, double *A, const int *LDA, const double *anorm, double *rcond, double *work, double *rwork, int *info); + +double CondNumber(NRMat > &A, const char norm) +{ + const char TypNorm = (tolower(norm) == 'o')?'I':'O'; //switch c-order/fortran-order + const int N = A.nrows(); + double Norma(0.0), ret(0.0); + int info; + complex *work; + double *rwork; + + if(N != A.ncols()){ + laerror("nonsquare matrix in zgecon"); + return 0.0; + } + work = new complex[2*N]; + rwork = new double[2*N]; + + Norma = MatrixNorm(A, norm); + FORNAME(zgecon)(&TypNorm, &N, A[0], &N, &Norma, &ret, &work[0], &rwork[0], &info); + delete[] work; + delete[] rwork; + return ret; +} + +double CondNumber(NRMat &A, const char norm) +{ + const char TypNorm = (tolower(norm) == 'o')?'I':'O'; //switch c-order/fortran-order + const int N = A.nrows(); + double Norma(0.0), ret(0.0); + int info; + double *work; + double *rwork; + + if(N != A.ncols()){ + laerror("nonsquare matrix in zgecon"); + return 0.0; + } + work = new double[2*N]; + rwork = new double[2*N]; + + Norma = MatrixNorm(A, norm); + FORNAME(dgecon)(&TypNorm, &N, A[0], &N, &Norma, &ret, &work[0], &rwork[0], &info); + delete[] work; + delete[] rwork; + return ret; +} #ifdef obsolete diff --git a/nonclass.h b/nonclass.h index 740c933..a9795c0 100644 --- a/nonclass.h +++ b/nonclass.h @@ -154,6 +154,15 @@ const NRMat inverse(NRMat a, T *det=0) return result; } +//several matrix norms +template +typename LA_traits::normtype MatrixNorm(const MAT &A, const char norm); + +//condition number +template +typename LA_traits::normtype CondNumber(const MAT &A, const char norm); + + //general determinant template const typename LA_traits::elementtype determinant(MAT a)//passed by value @@ -175,6 +184,53 @@ return det; } +//extended linear solve routines +template +extern int linear_solve_x_(NRMat &A, T *B, const bool eq, const int nrhs, const int ldb, const char trans); + + +//solve Ax = b using zgesvx +template +inline int linear_solve_x(NRMat > &A, NRVec > &B, const bool eq) +{ +B.copyonwrite(); +return linear_solve_x_(A, &B[0], eq, 1, B.size(), 'T'); +} + + +//solve AX = B using zgesvx +template +inline int linear_solve_x(NRMat > &A, NRMat > &B, const bool eq, const bool transpose=true) +{ +B.copyonwrite(); +if(transpose) B.transposeme();//because of corder +int info(0); +info = linear_solve_x_(A, B[0], eq, B.ncols(), B.nrows(), transpose?'T':'N'); +if(transpose) B.transposeme(); +return info; +} + + +#define multiply_by_inverse(P,Q,eq) linear_solve_x(P,Q,eq,false) +/* + * input: + * P,Q - general complex square matrices + * eq - use equilibration (man cgesvx) + * description: + * evaluates matrix expression QP^{-1} as + * Z = QP^{-1} + * ZP = Q + * P^TZ^T = Q^T + * Z is computed by solving this linear system instead of computing inverse + * of P followed by multiplication by Q + * returns: + * returns the info parameter of cgesvx + * result is stored in Q + */ + + + + //general submatrix, INDEX will typically be NRVec or even int* //NOTE: in order to check consistency between nrows and rows in rows is a NRVec //some advanced metaprogramming would be necessary diff --git a/smat.cc b/smat.cc index 75ddccd..c2f5ea5 100644 --- a/smat.cc +++ b/smat.cc @@ -20,6 +20,7 @@ #include "smat.h" #include +#include #include #include #include @@ -160,13 +161,13 @@ template void NRSMat::fscanf(FILE *f, const char *format) { int n, m; - if (std::fscanf(f,"%d %d",&n,&m) != 2) + if (::fscanf(f,"%d %d",&n,&m) != 2) laerror("cannot read matrix dimensions in SMat::fscanf"); if (n != m) laerror("different dimensions of SMat"); resize(n); for (int i=0; i void SparseSMat::gemv(const T beta, NRVec &r, const char trans, const T alpha, const NRVec &x) const { if(nn!=r.size() || mm!= x.size()) laerror("incompatible matrix vector dimensions in SparseSMat::gemv"); -if(trans) laerror("transposition not implemented yet in SparseSMat::gemv"); +if(tolower(trans)!='n') laerror("transposition not implemented yet in SparseSMat::gemv"); r *= beta; if(alpha == (T)0) return; r.copyonwrite(); diff --git a/vec.cc b/vec.cc index 03adfe0..7efd8a4 100644 --- a/vec.cc +++ b/vec.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -124,10 +125,10 @@ void NRVec::fscanf(FILE *f, const char *format) { int n; - if(std::fscanf(f, "%d", &n) != 1) laerror("cannot read vector dimension"); + if(::fscanf(f, "%d", &n) != 1) laerror("cannot read vector dimension"); resize(n); for (int i=0; i