*** empty log message ***

This commit is contained in:
jiri 2010-02-25 20:47:01 +00:00
parent 03ef09deb8
commit df9ac6894b
4 changed files with 274 additions and 95 deletions

3
mat.h
View File

@ -127,6 +127,9 @@ public:
void axpy(const T alpha, const NRMat &x); // this += a*x
inline const T amax() const;
const T trace() const;
NRMat & SwapRows();
NRMat & SwapCols();
NRMat & SwapRowsCols();
//members concerning sparse matrix
SparseSMat<T> operator*(const SparseSMat<T> &rhs) const;

View File

@ -238,71 +238,259 @@ linear_solve_do(a,&B[0],1,a.nrows(),det,n);
//other version of linear solver based on gesvx
//------------------------------------------------------------------------------
extern "C" void FORNAME(zgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, complex<double> *A, const int *lda, complex<double> *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, complex<double> *B, const int *ldb, complex<double> *X, const int *ldx, double *rcond, double *ferr, double *berr, complex<double> *work, double *rwork, int *info);
extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, double *ferr, double *berr, double *work, double *rwork, int *info);
extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, double *ferr, double *berr, double *work, int *iwork, int *info);
//------------------------------------------------------------------------------
// solves set of linear equations using dgesvx
// input:
// _A double precision matrix of dimension nn x mm, where min(nn, mm) >= n
// _B double prec. array dimensioned as nrhs x n
// _rhsCount nrhs - count of right hand sides
// _eqCount n - count of equations
// _eq use equilibration of matrix A before solving
// _saveA if set, do no overwrite A if equilibration in effect
// _rcond if not NULL, store the returned rcond value from dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
int linear_solve_x(NRMat<double> &_A, double *_B, const int _rhsCount, const int _eqCount, const bool _eq, const bool _saveA, double *_rcond){
const int A_rows = _A.nrows();
const int A_cols = _A.ncols();
int linear_solve_x_(NRMat<complex<double> > &A, complex<double> *B, const bool eq, const int nrhs, const int ldb, const char trans)
{
const int n_rows = A.nrows();
const int n_cols = A.ncols();
const char fact = _eq?'E':'N';
const char trans = 'T';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x");
const int n = n_rows;
const char fact = eq?'E':'N';
char equed = 'B';//fact = 'N' => equed is an output argument
if(_eqCount < 0 || _eqCount > A_rows || _eqCount > A_cols || _rhsCount < 0){
laerror("linear_solve_x: invalid input matrices");
}
int info, lwork;
double rcond, ferr[nrhs], berr[nrhs], rwork[2*n];
double R[n], C[n];
complex<double> *AF = new complex<double>[n*n];
complex<double> *work = new complex<double>[2*n];
NRMat<complex<double> > X(n, nrhs);
int ipiv[n];
double *A;
double * const _A_data = (double*)_A;
A.copyonwrite();
int info;
const int nrhs = _rhsCount;
const int n = _eqCount;
int lda = A_cols;
const int ldaf = lda;
FORNAME(zgesvx)(&fact, &trans, &n_rows, &nrhs, \
A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \
&B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info);
double rcond;
double ferr[nrhs], berr[nrhs], work[4*n];
double R[n], C[n];
delete[] work;
delete[] AF;
memcpy(B, X[0], sizeof(complex<double>)*n*nrhs);
return info;
int *const iwork = new int[n];
int *const ipiv = new int[n];
double *X = new double[n*nrhs];
double *AF = new double[ldaf*n];
A = _A_data;
if(_eq){
if(_saveA){//store the corresponding submatrix of _A (not needed provided fact=='N')
A = new double[n*n];
int offset1 = 0;int offset2 = 0;
for(register int i=0;i<n;i++){
cblas_dcopy(n, _A_data + offset1, 1, A + offset2, 1);
offset1 += A_cols;
offset2 += n;
}
lda = n;//!!!
}else{
_A.copyonwrite();
}
}
FORNAME(dgesvx)(&fact, &trans, &n, &nrhs, A, &lda, AF, &ldaf, &ipiv[0], &equed, &R[0], &C[0], _B, &n, X, &n, &rcond, ferr, berr, work, iwork, &info);
if(_rcond)*_rcond = rcond;
cblas_dcopy(n*nrhs, X, 1, _B, 1);//store the solution
delete[] iwork;delete[] ipiv;
delete[] AF;delete[] X;
if(_saveA){
delete[] A;
}
return info;
}
//------------------------------------------------------------------------------
// solves set of linear equations using zgesvx
// input:
// _A double precision complex matrix of dimension nn x mm, where min(nn, mm) >= n
// _B double prec. complex array dimensioned as nrhs x n
// _rhsCount nrhs - count of right hand sides
// _eqCount n - count of equations
// _eq use equilibration
// _saveA if set, do no overwrite A if equilibration in effect
// _rcond if not NULL, store the returned rcond value from dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
int linear_solve_x(NRMat<complex<double> > &_A, complex<double> *_B, const int _rhsCount, const int _eqCount, const bool _eq, const bool _saveA, double *_rcond){
const int A_rows = _A.nrows();
const int A_cols = _A.ncols();
const char fact = _eq?'E':'N';
const char trans = 'T';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
if(_eqCount < 0 || _eqCount > A_rows || _eqCount > A_cols || _rhsCount < 0){
laerror("linear_solve_x: invalid input matrices");
}
complex<double> *A;
complex<double> * const _A_data = (complex<double>*)_A;
int info;
const int nrhs = _rhsCount;
const int n = _eqCount;
int lda = A_cols;
const int ldaf = lda;
double rcond;
double ferr[nrhs], berr[nrhs];
double R[n], C[n], rwork[2*n];
complex<double> work[2*n];
int *const ipiv = new int[n];
complex<double> *X = new complex<double>[n*nrhs];
complex<double> *AF = new complex<double>[ldaf*n];
A = _A_data;
if(_eq){
if(_saveA){//store the corresponding submatrix of _A (not needed provided fact=='N')
A = new complex<double>[n*n];
int offset1 = 0;int offset2 = 0;
for(register int i=0;i<n;i++){
cblas_zcopy(n, _A_data + offset1, 1, A + offset2, 1);
offset1 += A_cols;
offset2 += n;
}
lda = n;//!!!
}else{
_A.copyonwrite();
}
}
FORNAME(zgesvx)(&fact, &trans, &n, &nrhs, A, &lda, AF, &ldaf, &ipiv[0], &equed, &R[0], &C[0], _B, &n, X, &n, &rcond, ferr, berr, work, rwork, &info);
int linear_solve_x_(NRMat<double> &A, double *B, const bool eq, const int nrhs, const int ldb, const char trans)
{
const int n_rows = A.nrows();
const int n_cols = A.ncols();
if(_rcond)*_rcond = rcond;
cblas_zcopy(n*nrhs, X, 1, _B, 1);//store the solution
if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x");
const int n = n_rows;
const char fact = eq?'E':'N';
char equed = 'B';//fact = 'N' => equed is an output argument
int info, lwork;
double rcond, ferr[nrhs], berr[nrhs], rwork[2*n];
double R[n], C[n];
double *AF = new double[n*n];
double *work = new double[2*n];
NRMat<double> X(n, nrhs);
int ipiv[n];
A.copyonwrite();
FORNAME(dgesvx)(&fact, &trans, &n_rows, &nrhs, \
A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \
&B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info);
delete[] work;
delete[] AF;
memcpy(B, X[0], sizeof(double)*n*nrhs);
return info;
delete[] ipiv;
delete[] AF;delete[] X;
if(_saveA){
delete[] A;
}
return info;
}
//------------------------------------------------------------------------------
// for given square matrices A, B computes X = AB^{-1} as follows
// XB = A => B^TX^T = A^T
// input:
// _A double precision matrix of dimension nn x nn
// _B double prec. matrix of dimension nn x nn
// _useEq use equilibration suitable for badly conditioned matrices
// _rcond if not NULL, store the returned value of rcond fromd dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
int multiply_by_inverse(NRMat<double> &_A, NRMat<double> &_B, bool _useEq, double *_rcond){
const int n = _A.nrows();
const int m = _A.ncols();
if(n != m || n != _B.nrows() || n != _B.ncols()){
laerror("multiply_by_inverse: incompatible matrices");
}
const char fact = _useEq?'E':'N';
const char trans = 'N';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
const int n2 = n*n;
double * const A = (double*)_A;
double * const B = (double*)_B;
_B.copyonwrite();//even if fact='N', call copyonwrite because the solution is going to be stored in _B
int info;
double rcond;
double ferr[n], berr[n], work[4*n];
double R[n], C[n];
int *const iwork = new int[n];
int *const ipiv = new int[n];
double *X = new double[n2];
double *AF = new double[n2];
FORNAME(dgesvx)(&fact, &trans, &n, &n, B, &n, AF, &n, &ipiv[0], &equed, &R[0], &C[0], A, &n, X, &n, &rcond, ferr, berr, work, iwork, &info);
if(_rcond)*_rcond = rcond;
cblas_dcopy(n2, X, 1, B, 1);//store the solution
delete[] iwork;delete[] ipiv;
delete[] AF;delete[] X;
return info;
}
//------------------------------------------------------------------------------
// for given square matrices A, B computes X = AB^{-1} as follows
// XB = A => B^TX^T = A^T
// input:
// _A double precision matrix of dimension nn x nn
// _B double prec. matrix of dimension nn x nn
// _useEq use equilibration suitable for badly conditioned matrices
// _rcond if not NULL, store the returned value of rcond fromd zgesvx
// output:
// solution is stored in _B
// the info parameter of zgesvx is returned (see man zgesvx)
//------------------------------------------------------------------------------
int multiply_by_inverse(NRMat<complex<double> > &_A, NRMat<complex<double> > &_B, bool _useEq, double *_rcond){
const int n = _A.nrows();
const int m = _A.ncols();
if(n != m || n != _B.nrows() || n != _B.ncols()){
laerror("multiply_by_inverse: incompatible matrices");
}
const int n2 = n*n;
const char fact = _useEq?'E':'N';
const char trans = 'N';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
complex<double> * const A = (complex<double>*)_A;
complex<double> * const B = (complex<double>*)_B;
_B.copyonwrite();//even if fact='N', call copyonwrite because the solution is going to be stored in _B
int info;
double rcond;
double ferr[n], berr[n];
double R[n], C[n], rwork[2*n];
complex<double> work[2*n];
int *const ipiv = new int[n];
complex<double> *X = new complex<double>[n2];
complex<double> *AF = new complex<double>[n2];
FORNAME(zgesvx)(&fact, &trans, &n, &n, B, &n, AF, &n, &ipiv[0], &equed, &R[0], &C[0], A, &n, X, &n, &rcond, ferr, berr, work, rwork, &info);
if(_rcond)*_rcond = rcond;
cblas_zcopy(n2, X, 1, B, 1);//store the solution
delete[] ipiv;
delete[] AF;delete[] X;
return info;
}
//------------------------------------------------------------------------------

View File

@ -88,8 +88,8 @@ extern const NRVec<T> diagofproduct(const NRMat<T> &a, const NRMat<T> &b,\
extern T trace2(const NRMat<T> &a, const NRMat<T> &b, bool trb=0); \
extern T trace2(const NRSMat<T> &a, const NRSMat<T> &b, const bool diagscaled=0);\
extern T trace2(const NRSMat<T> &a, const NRMat<T> &b, const bool diagscaled=0);\
extern void linear_solve(NRMat<T> &a, NRMat<T> *b, double *det=0,int n=0); \
extern void linear_solve(NRSMat<T> &a, NRMat<T> *b, double *det=0, int n=0); \
extern void linear_solve(NRMat<T> &a, NRMat<T> *b, double *det=0,int n=0); /*solve Ax^T=b^T (b is nrhs x n) */ \
extern void linear_solve(NRSMat<T> &a, NRMat<T> *b, double *det=0, int n=0); /*solve Ax^T=b^T (b is nrhs x n) */\
extern void linear_solve(NRMat<T> &a, NRVec<T> &b, double *det=0, int n=0); \
extern void linear_solve(NRSMat<T> &a, NRVec<T> &b, double *det=0, int n=0); \
extern void diagonalize(NRMat<T> &a, NRVec<LA_traits<T>::normtype> &w, const bool eivec=1, const bool corder=1, int n=0, NRMat<T> *b=NULL, const int itype=1); \
@ -184,51 +184,38 @@ return det;
}
//extended linear solve routines
template<class T>
extern int linear_solve_x_(NRMat<T> &A, T *B, const bool eq, const int nrhs, const int ldb, const char trans);
//solve Ax = b using zgesvx
//------------------------------------------------------------------------------
// solves set of linear equations using gesvx
// input:
// A double precision matrix of dimension nn x mm, where min(nn, mm) >= n
// B double prec. array dimensioned as nrhs x n
// rhsCount nrhs - count of right hand sides
// eqCount n - count of equations
// eq use equilibration of matrix A before solving
// saveA if set, do no overwrite A if equilibration in effect
// rcond if not NULL, store the returned rcond value from dgesvx
// output:
// solution is stored in B
// the info parameter of gesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
template<class T>
inline int linear_solve_x(NRMat<complex<double> > &A, NRVec<complex<double> > &B, const bool eq)
{
B.copyonwrite();
return linear_solve_x_(A, &B[0], eq, 1, B.size(), 'T');
}
int linear_solve_x(NRMat<T> &A, T *B, const int rhsCount, const int eqCount, const bool eq, const bool saveA, double *rcond);
//solve AX = B using zgesvx
//------------------------------------------------------------------------------
// for given square matrices A, B computes X = AB^{-1} as follows
// XB = A => B^TX^T = A^T
// input:
// _A double precision matrix of dimension nn x nn
// _B double prec. matrix of dimension nn x nn
// _useEq use equilibration suitable for badly conditioned matrices
// _rcond if not NULL, store the returned value of rcond fromd dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
template<class T>
inline int linear_solve_x(NRMat<complex<double> > &A, NRMat<complex<double> > &B, const bool eq, const bool transpose=true)
{
B.copyonwrite();
if(transpose) B.transposeme();//because of corder
int info(0);
info = linear_solve_x_(A, B[0], eq, B.ncols(), B.nrows(), transpose?'T':'N');
if(transpose) B.transposeme();
return info;
}
#define multiply_by_inverse(P,Q,eq) linear_solve_x(P,Q,eq,false)
/*
* input:
* P,Q - general complex square matrices
* eq - use equilibration (man cgesvx)
* description:
* evaluates matrix expression QP^{-1} as
* Z = QP^{-1}
* ZP = Q
* P^TZ^T = Q^T
* Z is computed by solving this linear system instead of computing inverse
* of P followed by multiplication by Q
* returns:
* returns the info parameter of cgesvx
* result is stored in Q
*/
int multiply_by_inverse(NRMat<T> &A, NRMat<T> &B, bool useEq, double *rcond);
//general submatrix, INDEX will typically be NRVec<int> or even int*

1
vec.h
View File

@ -139,6 +139,7 @@ public:
bool smaller(int i, int j) const {return LA_traits<T>::smaller(v[i],v[j]);};
void swap(int i, int j) {T tmp; tmp=v[i]; v[i]=v[j]; v[j]=tmp;};
int sort(int direction=0, int from=0, int to= -1, int *perm=NULL); //sort, ascending by default, returns parity of permutation
NRVec & CallOnMe(T (*_F)(const T &) ) {copyonwrite(); for(int i=0; i<nn; ++i) v[i] = _F(v[i]); return *this;};
};
}//namespace