Compare commits

...

3 Commits

2 changed files with 30 additions and 4 deletions

View File

@ -174,6 +174,7 @@ extern "C" {
#include "cblas.h"
}
#include <stdarg.h>
#ifndef AVOID_DUPLICATE_CBLAS_XERBLA
extern "C" void cblas_xerbla(int p, const char *rout, const char *form, ...)
{
va_list argptr;
@ -189,6 +190,7 @@ extern "C" void cblas_xerbla(int p, const char *rout, const char *form, ...)
va_end(argptr);
laerror("terminating in cblas_xerbla");
}
#endif
extern "C" int cblas_errprn(int ierr, int info, char *form, ...) {
char msg0[1024], *msg;

View File

@ -26,7 +26,7 @@
#include "qsort.h"
#include "fortran.h"
#define IPIV_DEBUG
#undef IPIV_DEBUG
namespace LA {
@ -155,7 +155,11 @@ static void linear_solve_do(NRMat<double> &A, double *B, const int nrhs, const i
for (int i=0; i<n; ++i) {double t=A[i][i]; if(!finite(t) || std::abs(t) < EPSDET ) {*det=0.; break;} else *det *=t;}
//find out whether ipiv are numbered from 0 or from 1
int shift=1;
for (int i=0; i<n; ++i) if(ipiv[i]==0) shift=0;
for (int i=0; i<n; ++i)
{
if(ipiv[i]==0) shift=0;
if(ipiv[i]<0 || ipiv[i]>n) laerror("problem with ipiv in clapack_dgesv");
}
#ifdef IPIV_DEBUG
std::cout <<"shift = "<<shift<<std::endl;
#endif
@ -254,6 +258,7 @@ extern "C" void FORNAME(zgesv)(const int *N, const int *NRHS, double *A, const i
void linear_solve(NRMat< std::complex<double> > &A, NRMat< std::complex<double> > *B, std::complex<double> *det, int n)
{
int r, *ipiv;
int iswap=0;
if (A.nrows() != A.ncols()) laerror("linear_solve() call for non-square matrix");
if (B && A.nrows() != B->ncols()) laerror("incompatible matrices in linear_solve()");
@ -270,12 +275,31 @@ void linear_solve(NRMat< std::complex<double> > &A, NRMat< std::complex<double>
delete[] ipiv;
laerror("illegal argument in lapack_gesv");
}
if (det && r>=0) {
if (det && r==0) {
*det = A[0][0];
for (int i=1; i<A.nrows(); ++i) *det *= A[i][i];
int shift=1;
for (int i=0; i<n; ++i)
{
if(ipiv[i]==0) shift=0;
if(ipiv[i]<0 || ipiv[i]>n) laerror("problem with ipiv in zgesv");
}
#ifdef IPIV_DEBUG
std::cout <<"shift = "<<shift<<std::endl;
#endif
//
//change sign of det by parity of ipiv permutation
for (int i=0; i<A.nrows(); ++i) *det = -(*det);
if(det) for (int i=0; i<A.nrows(); ++i) if(i+shift != ipiv[i]) {*det = -(*det); ++iswap;}
}
if(det && r>0) *det = 0;
#ifdef IPIV_DEBUG
std::cout <<"iswap = "<<iswap<<std::endl;
std::cout <<"ipiv = ";
for (int i=0; i<n; ++i) std::cout <<ipiv[i]<<" ";
std::cout <<std::endl;
#endif
delete [] ipiv;
if (r>0 && B) laerror("singular matrix in zgesv");
}