//////////////////////////////////////////////////////////////////////////// //LA traits classes and generally needed includes #ifndef _LA_TRAITS_INCL #define _LA_TRAITS_INCL using namespace std; #include #include #include #include #include "laerror.h" extern "C" { #include "cblas.h" } #ifdef _GLIBCPP_NO_TEMPLATE_EXPORT # define export #endif //forward declarations template class NRVec; template class NRMat; template class NRSMat; template class SparseMat; //let's do some simple template metaprogramming and preprocessing //to keep the thing general and compact typedef class scalar_false {}; typedef class scalar_true {}; //default is non-scalar template class isscalar { typedef scalar_false scalar_type; }; //specializations #define SCALAR(X) \ class isscalar {typedef scalar_true scalar_type;}; //declare what is scalar SCALAR(char) SCALAR(short) SCALAR(int) SCALAR(long) SCALAR(long long) SCALAR(unsigned char) SCALAR(unsigned short) SCALAR(unsigned int) SCALAR(unsigned long) SCALAR(unsigned long long) SCALAR(float) SCALAR(double) SCALAR(complex) SCALAR(complex) SCALAR(void *) #undef SCALAR //now declare the traits for scalars and for composed classes template struct LA_traits_aux {}; //complex scalars template struct LA_traits_aux, scalar_true> { typedef complex elementtype; typedef complex producttype; typedef C normtype; static normtype norm (const complex &x) {return abs(x);} static void axpy (complex &s, const complex &x, const complex &c) {s+=x*c;} static void get(int fd, complex &x, bool dimensions=0) {if(sizeof(complex)!=read(fd,&x,sizeof(complex))) laerror("read error");} static void put(int fd, const complex &x, bool dimensions=0) {if(sizeof(complex)!=write(fd,&x,sizeof(complex))) laerror("write error");} static void multiget(unsigned int n,int fd, complex *x, bool dimensions=0){if((ssize_t)(n*sizeof(complex))!=read(fd,x,n*sizeof(complex))) laerror("read error");} static void multiput(unsigned int n, int fd, const complex *x, bool dimensions=0) {if((ssize_t)(n*sizeof(complex))!=write(fd,x,n*sizeof(complex))) laerror("write error");} }; //non-complex scalars template struct LA_traits_aux { typedef C elementtype; typedef C producttype; typedef C normtype; static normtype norm (const C &x) {return abs(x);} static void axpy (C &s, const C &x, const C &c) {s+=x*c;} static void put(int fd, const C &x, bool dimensions=0) {if(sizeof(C)!=write(fd,&x,sizeof(C))) laerror("write error");} static void get(int fd, C &x, bool dimensions=0) {if(sizeof(C)!=read(fd,&x,sizeof(C))) laerror("read error");} static void multiput(unsigned int n,int fd, const C *x, bool dimensions=0){if((ssize_t)(n*sizeof(C))!=write(fd,x,n*sizeof(C))) laerror("write error");} static void multiget(unsigned int n, int fd, C *x, bool dimensions=0) {if((ssize_t)(n*sizeof(C))!=read(fd,x,n*sizeof(C))) laerror("read error");} }; //prepare for non-scalar classes template struct LA_traits; //forward declaration needed for template recursion #define generate_traits(X) \ template \ struct LA_traits_aux, scalar_false> { \ typedef C elementtype; \ typedef X producttype; \ typedef typename LA_traits::normtype normtype; \ static normtype norm (const X &x) {return x.norm();} \ static void axpy (X&s, const X &x, const C c) {s.axpy(c,x);} \ static void put(int fd, const C &x, bool dimensions=1) {x.put(fd,dimensions);} \ static void get(int fd, C &x, bool dimensions=1) {x.get(fd,dimensions);} \ static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i struct LA_traits_aux, scalar_false> { typedef C elementtype; typedef NRMat producttype; typedef typename LA_traits::normtype normtype; static normtype norm (const NRSMat &x) {return x.norm();} static void axpy (NRSMat&s, const NRSMat &x, const C c) {s.axpy(c,x);} static void put(int fd, const C &x, bool dimensions=1) {x.put(fd,dimensions);} static void get(int fd, C &x, bool dimensions=1) {x.get(fd,dimensions);} static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i struct LA_traits : LA_traits_aux::scalar_type> {}; #endif