LA_library/diis.h

128 lines
2.7 KiB
C
Raw Normal View History

2005-02-18 23:08:15 +01:00
//DIIS convergence acceleration according to Pulay: Chem. Phys. Lett. 73, 393 (1980); J. Comp. Chem. 3,556 (1982)
2005-02-17 00:00:03 +01:00
#ifndef _DIIS_H_
#define _DIIS_H_
#include "vec.h"
#include "smat.h"
#include "mat.h"
#include "sparsemat.h"
#include "nonclass.h"
#include "la_traits.h"
#include "auxstorage.h"
// T is some solution vector in form of NRVec, NRMat, or NRSMat over double or complex<double> fields
template<typename T>
class DIIS
{
int dim;
int aktdim;
bool incore;
int cyclicshift; //circular buffer of last dim vectors
typedef typename LA_traits<T>::elementtype Te;
NRSMat<Te> bmat;
AuxStorage<Te> *st;
T *stor;
public:
DIIS(const int n, const bool core=1);
~DIIS();
2005-02-17 01:30:09 +01:00
typename LA_traits<T>::normtype extrapolate(T &vec); //vec is input/output; returns square residual norm
2005-02-17 00:00:03 +01:00
};
template<typename T>
2005-02-17 01:30:09 +01:00
DIIS<T>::DIIS(const int n, const bool core) : dim(n), incore(core), bmat(n,n)
2005-02-17 00:00:03 +01:00
{
st=incore?NULL: new AuxStorage<Te>;
stor= incore? new T[dim] : NULL;
2005-02-17 01:30:09 +01:00
bmat= (Te)0; for(int i=1; i<n; ++i) bmat(0,i) = (Te)-1;
2005-02-17 00:00:03 +01:00
aktdim=cyclicshift=0;
}
//release whichever storage backend the constructor allocated
template<typename T>
DIIS<T>::~DIIS()
{
//deleting a null pointer is a no-op, so the unused backend needs no explicit test
delete st;
delete[] stor;
}
template<typename T>
2005-02-17 01:30:09 +01:00
typename LA_traits<T>::normtype DIIS<T>::extrapolate(T &vec)
2005-02-17 00:00:03 +01:00
{
//if dim exceeded, shift
if(aktdim==dim)
{
cyclicshift=(cyclicshift+1)%dim;
2005-02-17 01:30:09 +01:00
for(int i=1; i<dim-1; ++i)
2005-02-17 00:00:03 +01:00
for(int j=1; j<=i; ++j)
bmat(i,j)=bmat(i+1,j+1);
}
else
++aktdim;
//store vector
if(incore) stor[(aktdim-1+cyclicshift)%dim]=vec;
2005-02-17 01:30:09 +01:00
else st->put(vec,(aktdim-1+cyclicshift)%dim);
2005-02-17 00:00:03 +01:00
2005-02-17 01:30:09 +01:00
if(aktdim==1) return (typename LA_traits<T>::normtype)1000000000;
//calculate difference;
vec.copyonwrite();
if(incore) vec -= stor[(aktdim-2+cyclicshift)%dim];
else
{
T tmp=vec;
st->get(tmp,(aktdim-2+cyclicshift)%dim);
vec -= tmp;
}
//calculate overlaps of differences (if storage is cheap, they could rather be stored than recomputed)
typename LA_traits<T>::normtype norm=vec.norm();
bmat(aktdim-1,aktdim-1)= norm*norm;
2005-02-17 00:00:03 +01:00
if(incore)
2005-02-17 01:30:09 +01:00
for(int i=1; i<aktdim-1; ++i)
bmat(i,aktdim-1)=vec.dot(stor[(i+cyclicshift)%dim] - stor[(i-1+cyclicshift)%dim]);
2005-02-17 00:00:03 +01:00
else
{
2005-02-17 01:30:09 +01:00
T tmp=vec;
T tmp2=vec; //copy dimensions
st->get(tmp2,(0+cyclicshift)%dim);
for(int i=1; i<aktdim-1; ++i)
2005-02-17 00:00:03 +01:00
{
2005-02-17 01:30:09 +01:00
st->get(tmp,(i+cyclicshift)%dim);
tmp2 -= tmp;
bmat(i,aktdim-1)= -vec.dot(tmp2);
tmp2=tmp;
2005-02-17 00:00:03 +01:00
}
}
//prepare rhs-solution vector
2005-02-17 01:30:09 +01:00
NRVec<Te> rhs(dim);
2005-02-17 00:00:03 +01:00
rhs= (Te)0; rhs[0]= (Te)-1;
//solve for coefficients
{
NRSMat<Te> amat=bmat;
2005-02-17 01:30:09 +01:00
linear_solve(amat,rhs,NULL,aktdim);
2005-02-17 00:00:03 +01:00
}
//build the new linear combination
2005-02-17 01:30:09 +01:00
vec = (Te)0;
2005-02-17 00:00:03 +01:00
if(incore)
2005-02-17 01:30:09 +01:00
for(int i=1; i<aktdim; ++i) vec.axpy(rhs[i],stor[(i+cyclicshift)%dim]);
2005-02-17 00:00:03 +01:00
else
{
T tmp=vec; //copy dimensions
for(int i=1; i<aktdim; ++i)
{
2005-02-17 01:30:09 +01:00
st->get(tmp,(i+cyclicshift)%dim);
2005-02-17 00:00:03 +01:00
vec.axpy(rhs[i],tmp);
}
}
2005-02-17 01:30:09 +01:00
return norm;
2005-02-17 00:00:03 +01:00
}
#endif