LA_library/diis.h

158 lines
4.0 KiB
C
Raw Normal View History

2005-02-18 23:08:15 +01:00
//DIIS convergence acceleration according to Pulay: Chem. Phys. Lett. 73, 393 (1980); J. Comp. Chem. 3,556 (1982)
2005-02-17 00:00:03 +01:00
#ifndef _DIIS_H_
#define _DIIS_H_
#include "vec.h"
#include "smat.h"
#include "mat.h"
#include "sparsemat.h"
#include "nonclass.h"
#include "la_traits.h"
#include "auxstorage.h"
2007-06-03 22:17:57 +02:00
//following remarks in the Pulay memorial book: a small value is added to the diagonal of the B matrix for numerical stabilization
#define DIISEPS 1e-9
2006-09-18 23:46:45 +02:00
// Typically, T is some solution vector in form of NRVec, NRMat, or NRSMat over double or complex<double> fields
2006-09-18 23:53:50 +02:00
// in fact T (and likewise the error-vector type U) can be any type with a copy constructor, operator|=(const T&) (used to store copies), clear(), dot(), axpy(), norm() and copyonwrite(), for which LA_traits<T>::normtype and elementtype are defined
2006-09-18 23:46:45 +02:00
// and it must be storable through AuxStorage get() and put() if external (out-of-core) storage is requested
2005-02-17 00:00:03 +01:00
2007-06-04 00:49:29 +02:00
template<typename T, typename U>
2005-02-17 00:00:03 +01:00
class DIIS
{
int dim;
int aktdim;
bool incore;
int cyclicshift; //circular buffer of last dim vectors
typedef typename LA_traits<T>::elementtype Te;
2007-06-04 00:49:29 +02:00
typedef typename LA_traits<U>::elementtype Ue;
NRSMat<Ue> bmat;
AuxStorage<Te> *st;
AuxStorage<Ue> *errst;
T *stor;
U *errstor;
2005-02-17 00:00:03 +01:00
public:
2007-08-19 23:41:50 +02:00
DIIS() {dim=0; st=NULL; stor=NULL; errst=NULL; errstor=NULL;}; //for array of diis
2005-02-17 00:00:03 +01:00
DIIS(const int n, const bool core=1);
2006-09-18 23:46:45 +02:00
void setup(const int n, const bool core=1);
2005-02-17 00:00:03 +01:00
~DIIS();
2007-06-06 17:14:19 +02:00
typename LA_traits<U>::normtype extrapolate(T &vec, const U &errvec, bool verbose=false); //vec is input/output; returns square residual norm
2005-02-17 00:00:03 +01:00
};
2007-06-04 00:49:29 +02:00
template<typename T, typename U>
2007-06-06 17:14:19 +02:00
DIIS<T,U>::DIIS(const int n, const bool core) : dim(n), incore(core), bmat(n+1,n+1)
2005-02-17 00:00:03 +01:00
{
st=incore?NULL: new AuxStorage<Te>;
2007-06-04 00:49:29 +02:00
errst=incore?NULL: new AuxStorage<Ue>;
2005-02-17 00:00:03 +01:00
stor= incore? new T[dim] : NULL;
2007-06-04 00:49:29 +02:00
errstor= incore? new U[dim] : NULL;
2007-06-06 17:14:19 +02:00
bmat= (Ue)0; for(int i=1; i<=n; ++i) bmat(0,i) = (Ue)-1;
2005-02-17 00:00:03 +01:00
aktdim=cyclicshift=0;
2006-09-18 23:46:45 +02:00
}
2005-02-17 00:00:03 +01:00
2007-06-04 00:49:29 +02:00
//(Re)initialize the extrapolator for up to n vectors, in memory (core=true) or in
//AuxStorage (core=false). Any previously stored vectors are discarded.
template<typename T, typename U>
void DIIS<T,U>::setup(const int n, const bool core)
{
    //release storage allocated by the constructor or an earlier setup() call;
    //previously these allocations were simply overwritten and leaked.
    //All four pointers are either NULL or owned here, and delete on NULL is a no-op.
    delete st; st=NULL;
    delete errst; errst=NULL;
    delete[] stor; stor=NULL;
    delete[] errstor; errstor=NULL;

    dim=n;
    incore=core;
    bmat.resize(n+1);
    st=incore?NULL: new AuxStorage<Te>;
    errst=incore?NULL: new AuxStorage<Ue>;
    stor= incore? new T[dim] : NULL;
    errstor= incore? new U[dim] : NULL;
    //zero the B matrix and set the -1 border in row 0 (bmat(0,0) stays 0)
    bmat= (Ue)0; for(int i=1; i<=n; ++i) bmat(0,i) = (Ue)-1;
    aktdim=cyclicshift=0;
}
2006-09-18 23:46:45 +02:00
2007-06-04 00:49:29 +02:00
//Release whichever storage (AuxStorage objects or in-core arrays) was allocated.
template<typename T, typename U>
DIIS<T,U>::~DIIS()
{
    //unused pointers are NULL, and deleting a null pointer is a no-op
    delete st;
    delete errst;
    delete[] stor;
    delete[] errstor;
}
2007-06-03 22:17:57 +02:00
2007-06-04 00:49:29 +02:00
template<typename T, typename U>
2007-06-06 17:14:19 +02:00
typename LA_traits<U>::normtype DIIS<T,U>::extrapolate(T &vec, const U &errvec, bool verbose)
2005-02-17 00:00:03 +01:00
{
2006-09-18 23:46:45 +02:00
if(!dim) laerror("attempt to extrapolate from uninitialized DIIS");
2005-02-17 00:00:03 +01:00
//if dim exceeded, shift
if(aktdim==dim)
{
cyclicshift=(cyclicshift+1)%dim;
2007-06-06 17:14:19 +02:00
for(int i=1; i<dim; ++i)
2005-02-17 00:00:03 +01:00
for(int j=1; j<=i; ++j)
bmat(i,j)=bmat(i+1,j+1);
}
else
++aktdim;
//store vector
2007-06-03 22:17:57 +02:00
if(incore)
{
2007-06-06 17:14:19 +02:00
stor[(aktdim-1+cyclicshift)%dim]|=vec;
errstor[(aktdim-1+cyclicshift)%dim]|=errvec;
2007-06-03 22:17:57 +02:00
}
else
{
st->put(vec,(aktdim-1+cyclicshift)%dim);
errst->put(errvec,(aktdim-1+cyclicshift)%dim);
}
2005-02-17 00:00:03 +01:00
2007-06-06 17:19:28 +02:00
if(aktdim==1) return (typename LA_traits<T>::normtype)1;
2005-02-17 01:30:09 +01:00
2007-06-03 22:17:57 +02:00
//calculate overlaps of the new error with old ones
typename LA_traits<T>::normtype norm=errvec.norm();
2007-06-06 17:14:19 +02:00
bmat(aktdim,aktdim)= norm*norm + DIISEPS;
2005-02-17 00:00:03 +01:00
if(incore)
2007-06-06 17:14:19 +02:00
for(int i=1; i<aktdim; ++i)
bmat(i,aktdim)=errvec.dot(errstor[(i+cyclicshift-1)%dim]);
2005-02-17 00:00:03 +01:00
else
{
2007-06-04 00:49:29 +02:00
U tmp = errvec; tmp.copyonwrite(); //copy dimensions
2007-06-06 17:14:19 +02:00
for(int i=1; i<aktdim; ++i)
2005-02-17 00:00:03 +01:00
{
2007-06-06 17:14:19 +02:00
errst->get(tmp,(i-1+cyclicshift)%dim);
bmat(i,aktdim)= errvec.dot(tmp);
2005-02-17 00:00:03 +01:00
}
}
2007-06-03 22:17:57 +02:00
2005-02-17 00:00:03 +01:00
//prepare rhs-solution vector
2007-06-06 17:14:19 +02:00
NRVec<Ue> rhs(dim+1);
2007-06-04 00:49:29 +02:00
rhs= (Ue)0; rhs[0]= (Ue)-1;
2005-02-17 00:00:03 +01:00
//solve for coefficients
2007-06-03 22:17:57 +02:00
//@@@@@@ implement checking for bad condition number and eliminating old vectors
//@@@ explicit solution - cf. remarks in Pulay memorial book
2005-02-17 00:00:03 +01:00
{
NRSMat<Te> amat=bmat;
2007-06-06 17:14:19 +02:00
linear_solve(amat,rhs,NULL,aktdim+1);
2005-02-17 00:00:03 +01:00
}
2007-06-06 17:14:19 +02:00
if(verbose) cout <<"DIIS coefficients: "<<rhs<<endl;
2005-02-17 00:00:03 +01:00
//build the new linear combination
2006-09-18 23:53:50 +02:00
vec.clear();
2005-02-17 00:00:03 +01:00
if(incore)
2007-06-06 17:14:19 +02:00
for(int i=1; i<=aktdim; ++i)
vec.axpy(rhs[i],stor[(i-1+cyclicshift)%dim]);
2005-02-17 00:00:03 +01:00
else
{
T tmp=vec; //copy dimensions
2007-06-06 17:14:19 +02:00
for(int i=1; i<=aktdim; ++i)
2005-02-17 00:00:03 +01:00
{
2007-06-06 17:14:19 +02:00
st->get(tmp,(i-1+cyclicshift)%dim);
2005-02-17 00:00:03 +01:00
vec.axpy(rhs[i],tmp);
}
}
2005-02-17 01:30:09 +01:00
return norm;
2005-02-17 00:00:03 +01:00
}
#endif