improved diagnostics in clapack_dgesv replacement and bugfix in wrapper of dspsv

This commit is contained in:
2026-03-04 10:55:59 +01:00
parent dfa9369779
commit 061880fb9f
3 changed files with 29 additions and 13 deletions

View File

@@ -657,8 +657,9 @@ int clapack_dgesv(const CBLAS_ORDER Order, const int N, const int NRHS,
const FINT nrhstmp=NRHS;
const FINT ldatmp=lda;
const FINT ldbtmp=ldb;
FINT ipivtmp=*ipiv;
FORNAME(dgesv) (&ntmp,&nrhstmp,A,&ldatmp,&ipivtmp,B,&ldbtmp,&INFO);
FINT ipivtmp[N];
FORNAME(dgesv) (&ntmp,&nrhstmp,A,&ldatmp,ipivtmp,B,&ldbtmp,&INFO);
for(int i=0; i<N; ++i) ipiv[i]=ipivtmp[i];
#else
FORNAME(dgesv) (&N,&NRHS,A,&lda,ipiv,B,&ldb,&INFO);
#endif
@@ -672,6 +673,7 @@ int clapack_sgesv(const CBLAS_ORDER Order, const int N, const int NRHS,
float *A, const int lda, int *ipiv,
float *B, const int ldb)
{
std::cout <<"In my clapack_sgesv\n";
FINT INFO=0;
if(Order!=CblasRowMajor) laerror("CblasRowMajor order asserted");
//B should be in the same physical order, just transpose A in place and the LU result on output
@@ -681,8 +683,9 @@ int clapack_sgesv(const CBLAS_ORDER Order, const int N, const int NRHS,
const FINT nrhstmp=NRHS;
const FINT ldatmp=lda;
const FINT ldbtmp=ldb;
FINT ipivtmp=*ipiv;
FORNAME(sgesv) (&ntmp,&nrhstmp,A,&ldatmp,&ipivtmp,B,&ldbtmp,&INFO);
FINT ipivtmp[N];
FORNAME(sgesv) (&ntmp,&nrhstmp,A,&ldatmp,ipivtmp,B,&ldbtmp,&INFO);
for(int i=0; i<N; ++i) ipiv[i]=ipivtmp[i];
#else
FORNAME(sgesv) (&N,&NRHS,A,&lda,ipiv,B,&ldb,&INFO);
#endif

View File

@@ -26,7 +26,7 @@
#include "qsort.h"
#include "fortran.h"
#undef IPIV_DEBUG
//#define IPIV_DEBUG
namespace LA {
@@ -144,7 +144,11 @@ static void linear_solve_do(NRMat<double> &A, double *B, const int nrhs, const i
if (n==A.nrows() && A.nrows() != A.ncols()) laerror("linear_solve() call for non-square matrix");
A.copyonwrite();
ipiv = new int[A.nrows()];
r = clapack_dgesv(CblasRowMajor, n, nrhs, A[0], A.ncols(), ipiv, B , ldb);
#ifdef IPIV_DEBUG
for(int i=0; i<A.nrows(); ++i) ipiv[i]=123456789;
#endif
r = clapack_dgesv(CblasRowMajor, n, nrhs, &A(0,0), A.ncols(), ipiv, B , ldb);
// std::cout <<"A after clapack_dgesv = "<<A<<std::endl;
if (r < 0) {
delete[] ipiv;
laerror("illegal argument in lapack_gesv");
@@ -158,7 +162,11 @@ static void linear_solve_do(NRMat<double> &A, double *B, const int nrhs, const i
for (int i=0; i<n; ++i)
{
if(ipiv[i]==0) shift=0;
if(ipiv[i]<0 || ipiv[i]>n) laerror("problem with ipiv in clapack_dgesv");
if(ipiv[i]<0 || ipiv[i]>n)
{
std::cout <<"IPIV["<<i<<"] = "<<ipiv[i]<<std::endl;
laerror("problem with ipiv in clapack_dgesv");
}
}
#ifdef IPIV_DEBUG
std::cout <<"shift = "<<shift<<std::endl;
@@ -216,13 +224,14 @@ static void linear_solve_do(NRSMat<double> &a, double *b, const int nrhs, const
#else
FORNAME(dspsv)(&U, &n, &nrhs, a, ipiv, b, &ldb,&r);
#endif
// std::cout <<"A after dspsv = "<<a<<std::endl;
if (r < 0) {
delete[] ipiv;
laerror("illegal argument in spsv() call of linear_solve()");
}
if (det && r == 0) {
*det = 1.;
for (int i=1; i<n; i++) {double t=a(i,i); if(!finite(t) || std::abs(t) < EPSDET ) {*det=0.; break;} else *det *= t;}
for (int i=0; i<n; i++) {double t=a(i,i); if(!finite(t) || std::abs(t) < EPSDET ) {*det=0.; break;} else *det *= t;}
//do not use ipiv, since the permutation matrix occurs twice in the decomposition and signs thus cancel (man dspsv)
}
if (det && r>0) *det = 0;
@@ -282,7 +291,11 @@ void linear_solve(NRMat< std::complex<double> > &A, NRMat< std::complex<double>
for (int i=0; i<n; ++i)
{
if(ipiv[i]==0) shift=0;
if(ipiv[i]<0 || ipiv[i]>n) laerror("problem with ipiv in zgesv");
if(ipiv[i]<0 || ipiv[i]>n)
{
std::cout <<"IPIV["<<i<<"] = "<<ipiv[i]<<std::endl;
laerror("problem with ipiv in zgesv");
}
}
#ifdef IPIV_DEBUG
std::cout <<"shift = "<<shift<<std::endl;

8
t.cc
View File

@@ -1084,13 +1084,13 @@ NRMat<complex<double> > b=exp(a);
cout <<b;
}
if(0)
if(1)
{
int n;
double d;
cin >>n;
//NRMat<double> a(n,n);
NRSMat<double> a(n);
NRMat<double> a(n,n);
//NRSMat<double> a(n);
for(int i=0;i<n;++i) for(int j=0;j<=i;++j)
{
a(j,i)=a(i,j)=RANDDOUBLE()*(i==j?10.:1.);
@@ -4732,7 +4732,7 @@ cout <<"Error = "<<(a-aa).norm()<<endl;
}
if(1)
if(0)
{
int n;
cin>>n;