*** empty log message ***

This commit is contained in:
jiri
2010-02-25 20:47:01 +00:00
parent 03ef09deb8
commit df9ac6894b
4 changed files with 274 additions and 95 deletions

View File

@@ -238,71 +238,259 @@ linear_solve_do(a,&B[0],1,a.nrows(),det,n);
//other version of linear solver based on gesvx
//------------------------------------------------------------------------------
extern "C" void FORNAME(zgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, complex<double> *A, const int *lda, complex<double> *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, complex<double> *B, const int *ldb, complex<double> *X, const int *ldx, double *rcond, double *ferr, double *berr, complex<double> *work, double *rwork, int *info);
extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, double *ferr, double *berr, double *work, double *rwork, int *info);
extern "C" void FORNAME(dgesvx)(const char *fact, const char *trans, const int *n, const int *nrhs, double *A, const int *lda, double *AF, const int *ldaf, const int *ipiv, char *equed, double *R,double *C, double *B, const int *ldb, double *X, const int *ldx, double *rcond, double *ferr, double *berr, double *work, int *iwork, int *info);
//------------------------------------------------------------------------------
// solves set of linear equations using dgesvx
// input:
// _A double precision matrix of dimension nn x mm, where min(nn, mm) >= n
// _B double prec. array dimensioned as nrhs x n
// _rhsCount nrhs - count of right hand sides
// _eqCount n - count of equations
// _eq use equilibration of matrix A before solving
// _saveA if set, do no overwrite A if equilibration in effect
// _rcond if not NULL, store the returned rcond value from dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
int linear_solve_x(NRMat<double> &_A, double *_B, const int _rhsCount, const int _eqCount, const bool _eq, const bool _saveA, double *_rcond){
const int A_rows = _A.nrows();
const int A_cols = _A.ncols();
int linear_solve_x_(NRMat<complex<double> > &A, complex<double> *B, const bool eq, const int nrhs, const int ldb, const char trans)
{
const int n_rows = A.nrows();
const int n_cols = A.ncols();
const char fact = _eq?'E':'N';
const char trans = 'T';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x");
const int n = n_rows;
const char fact = eq?'E':'N';
char equed = 'B';//fact = 'N' => equed is an output argument
if(_eqCount < 0 || _eqCount > A_rows || _eqCount > A_cols || _rhsCount < 0){
laerror("linear_solve_x: invalid input matrices");
}
int info, lwork;
double rcond, ferr[nrhs], berr[nrhs], rwork[2*n];
double R[n], C[n];
complex<double> *AF = new complex<double>[n*n];
complex<double> *work = new complex<double>[2*n];
NRMat<complex<double> > X(n, nrhs);
int ipiv[n];
double *A;
double * const _A_data = (double*)_A;
A.copyonwrite();
int info;
const int nrhs = _rhsCount;
const int n = _eqCount;
int lda = A_cols;
const int ldaf = lda;
FORNAME(zgesvx)(&fact, &trans, &n_rows, &nrhs, \
A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \
&B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info);
double rcond;
double ferr[nrhs], berr[nrhs], work[4*n];
double R[n], C[n];
delete[] work;
delete[] AF;
memcpy(B, X[0], sizeof(complex<double>)*n*nrhs);
return info;
int *const iwork = new int[n];
int *const ipiv = new int[n];
double *X = new double[n*nrhs];
double *AF = new double[ldaf*n];
A = _A_data;
if(_eq){
if(_saveA){//store the corresponding submatrix of _A (not needed provided fact=='N')
A = new double[n*n];
int offset1 = 0;int offset2 = 0;
for(register int i=0;i<n;i++){
cblas_dcopy(n, _A_data + offset1, 1, A + offset2, 1);
offset1 += A_cols;
offset2 += n;
}
lda = n;//!!!
}else{
_A.copyonwrite();
}
}
FORNAME(dgesvx)(&fact, &trans, &n, &nrhs, A, &lda, AF, &ldaf, &ipiv[0], &equed, &R[0], &C[0], _B, &n, X, &n, &rcond, ferr, berr, work, iwork, &info);
if(_rcond)*_rcond = rcond;
cblas_dcopy(n*nrhs, X, 1, _B, 1);//store the solution
delete[] iwork;delete[] ipiv;
delete[] AF;delete[] X;
if(_saveA){
delete[] A;
}
return info;
}
//------------------------------------------------------------------------------
// solves set of linear equations using zgesvx
// input:
// _A double precision complex matrix of dimension nn x mm, where min(nn, mm) >= n
// _B double prec. complex array dimensioned as nrhs x n
// _rhsCount nrhs - count of right hand sides
// _eqCount n - count of equations
// _eq use equilibration
// _saveA if set, do no overwrite A if equilibration in effect
// _rcond if not NULL, store the returned rcond value from dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
int linear_solve_x(NRMat<complex<double> > &_A, complex<double> *_B, const int _rhsCount, const int _eqCount, const bool _eq, const bool _saveA, double *_rcond){
const int A_rows = _A.nrows();
const int A_cols = _A.ncols();
const char fact = _eq?'E':'N';
const char trans = 'T';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
if(_eqCount < 0 || _eqCount > A_rows || _eqCount > A_cols || _rhsCount < 0){
laerror("linear_solve_x: invalid input matrices");
}
complex<double> *A;
complex<double> * const _A_data = (complex<double>*)_A;
int info;
const int nrhs = _rhsCount;
const int n = _eqCount;
int lda = A_cols;
const int ldaf = lda;
double rcond;
double ferr[nrhs], berr[nrhs];
double R[n], C[n], rwork[2*n];
complex<double> work[2*n];
int *const ipiv = new int[n];
complex<double> *X = new complex<double>[n*nrhs];
complex<double> *AF = new complex<double>[ldaf*n];
A = _A_data;
if(_eq){
if(_saveA){//store the corresponding submatrix of _A (not needed provided fact=='N')
A = new complex<double>[n*n];
int offset1 = 0;int offset2 = 0;
for(register int i=0;i<n;i++){
cblas_zcopy(n, _A_data + offset1, 1, A + offset2, 1);
offset1 += A_cols;
offset2 += n;
}
lda = n;//!!!
}else{
_A.copyonwrite();
}
}
FORNAME(zgesvx)(&fact, &trans, &n, &nrhs, A, &lda, AF, &ldaf, &ipiv[0], &equed, &R[0], &C[0], _B, &n, X, &n, &rcond, ferr, berr, work, rwork, &info);
int linear_solve_x_(NRMat<double> &A, double *B, const bool eq, const int nrhs, const int ldb, const char trans)
{
const int n_rows = A.nrows();
const int n_cols = A.ncols();
if(_rcond)*_rcond = rcond;
cblas_zcopy(n*nrhs, X, 1, _B, 1);//store the solution
if(n_rows != n_cols)laerror("non-squre matrix in linear_solve_x");
const int n = n_rows;
const char fact = eq?'E':'N';
char equed = 'B';//fact = 'N' => equed is an output argument
int info, lwork;
double rcond, ferr[nrhs], berr[nrhs], rwork[2*n];
double R[n], C[n];
double *AF = new double[n*n];
double *work = new double[2*n];
NRMat<double> X(n, nrhs);
int ipiv[n];
A.copyonwrite();
FORNAME(dgesvx)(&fact, &trans, &n_rows, &nrhs, \
A[0], &n_rows, &AF[0], &n_rows, &ipiv[0], &equed, &R[0], &C[0], \
&B[0], &ldb, X[0], &n_rows, &rcond, &ferr[0], &berr[0], &work[0], &rwork[0], &info);
delete[] work;
delete[] AF;
memcpy(B, X[0], sizeof(double)*n*nrhs);
return info;
delete[] ipiv;
delete[] AF;delete[] X;
if(_saveA){
delete[] A;
}
return info;
}
//------------------------------------------------------------------------------
// for given square matrices A, B computes X = AB^{-1} as follows
// XB = A => B^TX^T = A^T
// input:
// _A double precision matrix of dimension nn x nn
// _B double prec. matrix of dimension nn x nn
// _useEq use equilibration suitable for badly conditioned matrices
// _rcond if not NULL, store the returned value of rcond fromd dgesvx
// output:
// solution is stored in _B
// the info parameter of dgesvx is returned (see man dgesvx)
//------------------------------------------------------------------------------
int multiply_by_inverse(NRMat<double> &_A, NRMat<double> &_B, bool _useEq, double *_rcond){
const int n = _A.nrows();
const int m = _A.ncols();
if(n != m || n != _B.nrows() || n != _B.ncols()){
laerror("multiply_by_inverse: incompatible matrices");
}
const char fact = _useEq?'E':'N';
const char trans = 'N';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
const int n2 = n*n;
double * const A = (double*)_A;
double * const B = (double*)_B;
_B.copyonwrite();//even if fact='N', call copyonwrite because the solution is going to be stored in _B
int info;
double rcond;
double ferr[n], berr[n], work[4*n];
double R[n], C[n];
int *const iwork = new int[n];
int *const ipiv = new int[n];
double *X = new double[n2];
double *AF = new double[n2];
FORNAME(dgesvx)(&fact, &trans, &n, &n, B, &n, AF, &n, &ipiv[0], &equed, &R[0], &C[0], A, &n, X, &n, &rcond, ferr, berr, work, iwork, &info);
if(_rcond)*_rcond = rcond;
cblas_dcopy(n2, X, 1, B, 1);//store the solution
delete[] iwork;delete[] ipiv;
delete[] AF;delete[] X;
return info;
}
//------------------------------------------------------------------------------
// for given square matrices A, B computes X = AB^{-1} as follows
// XB = A => B^TX^T = A^T
// input:
// _A double precision matrix of dimension nn x nn
// _B double prec. matrix of dimension nn x nn
// _useEq use equilibration suitable for badly conditioned matrices
// _rcond if not NULL, store the returned value of rcond fromd zgesvx
// output:
// solution is stored in _B
// the info parameter of zgesvx is returned (see man zgesvx)
//------------------------------------------------------------------------------
int multiply_by_inverse(NRMat<complex<double> > &_A, NRMat<complex<double> > &_B, bool _useEq, double *_rcond){
const int n = _A.nrows();
const int m = _A.ncols();
if(n != m || n != _B.nrows() || n != _B.ncols()){
laerror("multiply_by_inverse: incompatible matrices");
}
const int n2 = n*n;
const char fact = _useEq?'E':'N';
const char trans = 'N';//because of c-order
char equed = 'B';//if fact=='N' then equed is an output argument, therefore not declared as const
complex<double> * const A = (complex<double>*)_A;
complex<double> * const B = (complex<double>*)_B;
_B.copyonwrite();//even if fact='N', call copyonwrite because the solution is going to be stored in _B
int info;
double rcond;
double ferr[n], berr[n];
double R[n], C[n], rwork[2*n];
complex<double> work[2*n];
int *const ipiv = new int[n];
complex<double> *X = new complex<double>[n2];
complex<double> *AF = new complex<double>[n2];
FORNAME(zgesvx)(&fact, &trans, &n, &n, B, &n, AF, &n, &ipiv[0], &equed, &R[0], &C[0], A, &n, X, &n, &rcond, ferr, berr, work, rwork, &info);
if(_rcond)*_rcond = rcond;
cblas_zcopy(n2, X, 1, B, 1);//store the solution
delete[] ipiv;
delete[] AF;delete[] X;
return info;
}
//------------------------------------------------------------------------------