/* LA: linear algebra C++ interface library Copyright (C) 2008 Jiri Pittner or This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ //////////////////////////////////////////////////////////////////////////// //LA traits classes and generally needed includes #ifndef _LA_TRAITS_INCL #define _LA_TRAITS_INCL using namespace std; #include #include #include #include #include "laerror.h" #ifdef NONCBLAS #include "noncblas.h" #else extern "C" { #include "cblas.h" } #endif //forward declarations template class NRVec; template class NRMat; template class NRMat_from1; template class NRSMat; template class NRSMat_from1; template class SparseMat; //for general sortable classes template struct LA_sort_traits; template struct LA_sort_traits { static inline bool compare(T object, I i, I j) {return object.bigger(i,j);}; }; template struct LA_sort_traits { static inline bool compare(T object, I i, I j) {return object.smaller(i,j);}; }; //we will need to treat char and unsigned char as numbers in << and >> I/O operators template struct LA_traits_io { typedef C IOtype; }; template<> struct LA_traits_io { typedef int IOtype; }; template<> struct LA_traits_io { typedef unsigned int IOtype; }; //let's do some simple template metaprogramming and preprocessing //to keep the thing general and compact typedef class scalar_false {}; typedef class scalar_true {}; //default is non-scalar template class isscalar { public: typedef scalar_false scalar_type;}; //specializations 
/*
 * Scalar classification, LA_traits_aux specializations and the final LA_traits.
 *
 * Contents, in order:
 *  - SCALAR(X): defines an isscalar specialization whose scalar_type is
 *    scalar_true (the <X> argument itself was lost, see below). It is applied
 *    to char, short, int, long, long long, their unsigned variants, float,
 *    double, two complex types and void *, then #undef'd.
 *  - LA_traits_aux: empty primary template; all methods of the specializations
 *    are static because the traits classes are never instantiated.
 *  - complex scalars:
 *      - elementtype/producttype are the complex type; normtype is C.
 *      - gencmp() returns memcmp(...) converted to bool, i.e. true when the n
 *        elements DIFFER (not when they are equal).
 *      - bigger()/smaller() call laerror("complex comparison undefined") and
 *        return false.
 *      - norm() is abs(x); axpy() does s+=x*c.
 *      - get/put/multiget/multiput do raw read()/write() on fd and call
 *        laerror when the byte count does not match; the dimensions/transp
 *        flags are unused.
 *      - copy() is memcpy, clear() is memset to 0, copyonwrite() does nothing.
 *  - non-complex scalars: gencmp() via memcmp (same "true = differ"
 *    convention); bigger() is x>y.
 *  - generate_traits(X), applied to NRMat, NRMat_from1, NRVec, SparseMat:
 *      - norm(), axpy(), put(), get() and copyonwrite() forward to the
 *        object's own member functions.
 *      - normtype is taken from LA_traits, which is forward-declared just
 *        before the macro for this recursion.
 *  - generate_traits_smat(X), applied to NRSMat, NRSMat_from1: same as above,
 *    except that producttype is NRMat and put()/get() do not pass transp on.
 *  - LA_traits: derives from LA_traits_aux, selecting the specialization via a
 *    ::scalar_type (presumably isscalar<C>::scalar_type; that text was lost).
 *
 * NOTE(review): the text below is damaged and cannot compile as-is.
 *  - It is collapsed onto a few physical lines, so each "//" comment swallows
 *    the rest of its line. That includes the SCALAR(...) uses, everything
 *    after "//non-complex scalars", and the closing #endif.
 *  - The backslash continuations of the #define macros are no longer at end
 *    of line.
 *  - Every "<...>" span appears to have been stripped, so template parameter
 *    lists and arguments are missing ("class isscalar {", bare "complex",
 *    "SCALAR(complex)" twice).
 *  - Where a "<" was followed much later by a ">", whole members vanished:
 *      - the non-complex specialization runs from "{return x" straight into
 *        "struct LA_traits;";
 *      - the generate_traits* macros jump from "for(int i=0; i" to "y;}";
 *      - they also jump from "for(unsigned int i=0; i" to
 *        "&x) {x.copyonwrite();}".
 *  The missing members cannot be reliably reconstructed from this text alone;
 *  restore this part of the file from version control.
 */
#define SCALAR(X) \ template<>\ class isscalar {public: typedef scalar_true scalar_type;}; //declare what is scalar SCALAR(char) SCALAR(short) SCALAR(int) SCALAR(long) SCALAR(long long) SCALAR(unsigned char) SCALAR(unsigned short) SCALAR(unsigned int) SCALAR(unsigned long) SCALAR(unsigned long long) SCALAR(float) SCALAR(double) SCALAR(complex) SCALAR(complex) SCALAR(void *) #undef SCALAR //now declare the traits for scalars and for composed classes //NOTE! methods in traits classes have to be declared static, //since the class itself is never instantiated. //for performance, it can be also inlined at the same time template struct LA_traits_aux {}; //TRAITS SPECIALIZATIONS //complex scalars template struct LA_traits_aux, scalar_true> { typedef complex elementtype; typedef complex producttype; typedef C normtype; static inline bool gencmp(const complex *x, const complex *y, int n) {return memcmp(x,y,n*sizeof(complex));} static bool bigger(const complex &x, const complex &y) {laerror("complex comparison undefined"); return false;} static bool smaller(const complex &x, const complex &y) {laerror("complex comparison undefined"); return false;} static inline normtype norm (const complex &x) {return abs(x);} static inline void axpy (complex &s, const complex &x, const complex &c) {s+=x*c;} static inline void get(int fd, complex &x, bool dimensions=0, bool transp=0) {if(sizeof(complex)!=read(fd,&x,sizeof(complex))) laerror("read error");} static inline void put(int fd, const complex &x, bool dimensions=0, bool transp=0) {if(sizeof(complex)!=write(fd,&x,sizeof(complex))) laerror("write error");} static void multiget(unsigned int n,int fd, complex *x, bool dimensions=0){if((ssize_t)(n*sizeof(complex))!=read(fd,x,n*sizeof(complex))) laerror("read error");} static void multiput(unsigned int n, int fd, const complex *x, bool dimensions=0) {if((ssize_t)(n*sizeof(complex))!=write(fd,x,n*sizeof(complex))) laerror("write error");} static void copy(complex *dest, complex *src, 
unsigned int n) {memcpy(dest,src,n*sizeof(complex));} static void clear(complex *dest, unsigned int n) {memset(dest,0,n*sizeof(complex));} static void copyonwrite(complex &x) {}; }; //non-complex scalars template struct LA_traits_aux { typedef C elementtype; typedef C producttype; typedef C normtype; static inline bool gencmp(const C *x, const C *y, int n) {return memcmp(x,y,n*sizeof(C));} static inline bool bigger(const C &x, const C &y) {return x>y;} static inline bool smaller(const C &x, const C &y) {return x struct LA_traits; //forward declaration needed for template recursion #define generate_traits(X) \ template \ struct LA_traits_aux, scalar_false> { \ typedef C elementtype; \ typedef X producttype; \ typedef typename LA_traits::normtype normtype; \ static bool gencmp(const C *x, const C *y, int n) {for(int i=0; iy;} \ static inline bool smaller(const C &x, const C &y) {return x &x) {return x.norm();} \ static inline void axpy (X&s, const X &x, const C c) {s.axpy(c,x);} \ static void put(int fd, const C &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions,transp);} \ static void get(int fd, C &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions,transp);} \ static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i &x) {x.copyonwrite();}\ }; //non-scalar types defined in this library generate_traits(NRMat) generate_traits(NRMat_from1) generate_traits(NRVec) generate_traits(SparseMat) #undef generate_traits //smat #define generate_traits_smat(X) \ template \ struct LA_traits_aux, scalar_false> { \ typedef C elementtype; \ typedef NRMat producttype; \ typedef typename LA_traits::normtype normtype; \ static bool gencmp(const C *x, const C *y, int n) {for(int i=0; iy;} \ static inline bool smaller(const C &x, const C &y) {return x &x) {return x.norm();} \ static inline void axpy (X&s, const X &x, const C c) {s.axpy(c,x);} \ static void put(int fd, const C &x, bool dimensions=1, bool transp=0) 
{x.put(fd,dimensions);} \ static void get(int fd, C &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions);} \ static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i &x) {x.copyonwrite();} \ }; generate_traits_smat(NRSMat) generate_traits_smat(NRSMat_from1) //the final traits class template struct LA_traits : LA_traits_aux::scalar_type> {}; #endif