LA_library/la_traits.h

169 lines
6.4 KiB
C
Raw Normal View History

2004-03-17 04:07:21 +01:00
////////////////////////////////////////////////////////////////////////////
2005-02-18 23:08:15 +01:00
//LA traits classes and generally needed includes
2004-03-17 04:07:21 +01:00
#ifndef _LA_TRAITS_INCL
#define _LA_TRAITS_INCL
2005-02-18 23:08:15 +01:00
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <sys/types.h>
#include <unistd.h>
#include <cmath>
#include <complex>
#include <cstdlib>
#include <iostream>
//kept for client code relying on it; placed after the standard headers so that
//namespace std is already declared, but still before "laerror.h" which may rely on it
using namespace std;
#include "laerror.h"
#ifdef NONCBLAS
#include "noncblas.h"
#else
extern "C" {
#include "cblas.h"
}
#endif
//when _GLIBCPP_NO_TEMPLATE_EXPORT is defined (presumably by an old libstdc++ lacking
//exported templates), the 'export' keyword is defined away to nothing
#if defined(_GLIBCPP_NO_TEMPLATE_EXPORT)
#define export
#endif

//forward declarations of the library containers; the traits below only
//need them as template names (the definitions live elsewhere)
template<class C> class NRVec;
template<class C> class NRMat;
template<class C> class NRSMat;
template<class C> class SparseMat;
//let's do some simple template metaprogramming and preprocessing
//to keep the thing general and compact

//tag types selecting the scalar / non-scalar branch of LA_traits_aux
//(previously written as "typedef class X {};", a typedef without a name,
//which compilers ignore with a warning)
struct scalar_false {};
struct scalar_true {};

//default is non-scalar
//scalar_type must be public: LA_traits reads isscalar<C>::scalar_type from outside
template<typename C>
struct isscalar {
	typedef scalar_false scalar_type;
};

//specializations: mark X as a scalar type
#define SCALAR(X) \
template<> \
struct isscalar<X> {typedef scalar_true scalar_type;};

//declare what is scalar
SCALAR(char)
SCALAR(short)
SCALAR(int)
SCALAR(long)
SCALAR(long long)
SCALAR(unsigned char)
SCALAR(unsigned short)
SCALAR(unsigned int)
SCALAR(unsigned long)
SCALAR(unsigned long long)
SCALAR(float)
SCALAR(double)
SCALAR(std::complex<float>)
SCALAR(std::complex<double>)
SCALAR(void *)
#undef SCALAR
//traits for scalars and for composed (container) classes
//NOTE: every member of the traits classes is static - the classes act only as
//compile-time lookup tables and are never instantiated as objects;
//members defined inside the class body are implicitly inline as well

//primary template: intentionally empty, the specializations carry the members
template<class C, class Scalar>
struct LA_traits_aux
{
};

//TRAITS SPECIALIZATIONS
//complex scalars
template<typename C>
struct LA_traits_aux<complex<C>, scalar_true> {
	typedef complex<C> elementtype;
	typedef complex<C> producttype;
	typedef C normtype; //abs() of a complex<C> yields a C

	//true when the n elements are NOT bytewise identical (memcmp!=0), false otherwise
	static inline bool gencmp(const complex<C> *x, const complex<C> *y, int n) {return memcmp(x,y,n*sizeof(complex<C>))!=0;}
	//complex numbers are not ordered: report via laerror and answer false
	static bool bigger(const complex<C> &, const complex<C> &) {laerror("complex comparison undefined"); return false;}
	static bool smaller(const complex<C> &, const complex<C> &) {laerror("complex comparison undefined"); return false;}
	static inline normtype norm (const complex<C> &x) {return abs(x);}
	//s += x*c
	static inline void axpy (complex<C> &s, const complex<C> &x, const complex<C> &c) {s+=x*c;}

	//raw binary I/O of the object bytes; dimensions/transp exist only for a uniform
	//interface with the container traits and are ignored for scalars
	static inline void get(int fd, complex<C> &x, bool /*dimensions*/=0, bool /*transp*/=0) {readbytes(fd,&x,sizeof(complex<C>));}
	static inline void put(int fd, const complex<C> &x, bool /*dimensions*/=0, bool /*transp*/=0) {writebytes(fd,&x,sizeof(complex<C>));}
	static void multiget(unsigned int n,int fd, complex<C> *x, bool /*dimensions*/=0) {readbytes(fd,x,n*sizeof(complex<C>));}
	static void multiput(unsigned int n, int fd, const complex<C> *x, bool /*dimensions*/=0) {writebytes(fd,x,n*sizeof(complex<C>));}

private:
	//read exactly len bytes: one read() may legally transfer fewer bytes than asked
	//(e.g. Linux caps a single transfer near 2 GiB) or fail with EINTR, so loop;
	//EOF or a genuine error is reported via laerror
	static void readbytes(int fd, void *buf, size_t len)
	{
		char *p=static_cast<char *>(buf);
		while(len>0) {
			ssize_t r=read(fd,p,len);
			if(r<0 && errno==EINTR) continue;
			if(r<=0) {laerror("read error"); return;}
			p+=r;
			len-=static_cast<size_t>(r);
		}
	}
	//write exactly len bytes, retrying on short writes and EINTR; errors via laerror
	static void writebytes(int fd, const void *buf, size_t len)
	{
		const char *p=static_cast<const char *>(buf);
		while(len>0) {
			ssize_t w=write(fd,p,len);
			if(w<0 && errno==EINTR) continue;
			if(w<=0) {laerror("write error"); return;}
			p+=w;
			len-=static_cast<size_t>(w);
		}
	}
};
//non-complex scalars
template<typename C>
struct LA_traits_aux<C, scalar_true> {
2004-03-17 04:07:21 +01:00
typedef C elementtype;
typedef C producttype;
2005-02-14 01:10:07 +01:00
typedef C normtype;
2005-09-06 17:55:07 +02:00
static inline bool gencmp(const C *x, const C *y, int n) {return memcmp(x,y,n*sizeof(C));}
static inline bool bigger(const C &x, const C &y) {return x>y;}
static inline bool smaller(const C &x, const C &y) {return x<y;}
static inline normtype norm (const C &x) {return abs(x);}
static inline void axpy (C &s, const C &x, const C &c) {s+=x*c;}
2005-09-11 22:04:24 +02:00
static inline void put(int fd, const C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=write(fd,&x,sizeof(C))) laerror("write error");}
static inline void get(int fd, C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=read(fd,&x,sizeof(C))) laerror("read error");}
2005-02-14 01:10:07 +01:00
static void multiput(unsigned int n,int fd, const C *x, bool dimensions=0){if((ssize_t)(n*sizeof(C))!=write(fd,x,n*sizeof(C))) laerror("write error");}
static void multiget(unsigned int n, int fd, C *x, bool dimensions=0) {if((ssize_t)(n*sizeof(C))!=read(fd,x,n*sizeof(C))) laerror("read error");}
2004-03-17 04:07:21 +01:00
};
2005-09-06 17:55:07 +02:00
//non-scalars except smat
2005-02-14 01:10:07 +01:00
template<typename C>
struct LA_traits; //forward declaration needed for template recursion
#define generate_traits(X) \
template<typename C> \
struct LA_traits_aux<X<C>, scalar_false> { \
typedef C elementtype; \
typedef X<C> producttype; \
typedef typename LA_traits<C>::normtype normtype; \
2005-09-06 17:55:07 +02:00
static bool gencmp(const C *x, const C *y, int n) {for(int i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
static inline bool bigger(const C &x, const C &y) {return x>y;} \
static inline bool smaller(const C &x, const C &y) {return x<y;} \
static inline normtype norm (const X<C> &x) {return x.norm();} \
static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);} \
2005-09-11 22:04:24 +02:00
static void put(int fd, const C &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions,transp);} \
static void get(int fd, C &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions,transp);} \
2005-02-14 01:10:07 +01:00
static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].put(fd,dimensions);} \
static void multiget(unsigned int n,int fd, C *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].get(fd,dimensions);} \
2004-03-17 04:07:21 +01:00
};
2005-02-14 01:10:07 +01:00
//non-scalar types defined in this library
generate_traits(NRMat)
generate_traits(NRVec)
generate_traits(SparseMat)
#undef generate_traits
2005-09-06 17:55:07 +02:00
//smat
2005-02-14 01:10:07 +01:00
template<typename C>
struct LA_traits_aux<NRSMat<C>, scalar_false> {
typedef C elementtype;
typedef NRMat<C> producttype;
typedef typename LA_traits<C>::normtype normtype;
2005-09-06 17:55:07 +02:00
static bool gencmp(const C *x, const C *y, int n) {for(int i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;}
static inline bool bigger(const C &x, const C &y) {return x>y;}
static inline bool smaller(const C &x, const C &y) {return x<y;}
static inline normtype norm (const NRSMat<C> &x) {return x.norm();}
static inline void axpy (NRSMat<C>&s, const NRSMat<C> &x, const C c) {s.axpy(c,x);}
2005-09-11 22:04:24 +02:00
static void put(int fd, const C &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions);}
static void get(int fd, C &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions);}
2005-02-14 01:10:07 +01:00
static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].put(fd,dimensions);} \
static void multiget(unsigned int n,int fd, C *x, bool dimensions=1) {for(unsigned int i=0; i<n; ++i) x[i].get(fd,dimensions);} \
2004-03-17 04:07:21 +01:00
};
2005-02-14 01:10:07 +01:00
//the final traits class: inherits the scalar or the non-scalar specialization
//of LA_traits_aux, as selected by isscalar<C>::scalar_type
template<class C>
struct LA_traits
	: public LA_traits_aux<C, typename isscalar<C>::scalar_type>
{
};
#endif