/* LA: linear algebra C++ interface library Copyright (C) 2008 Jiri Pittner or This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* Traits classes for the LA library: forward declarations of the container templates, the LA_traits_complex, LA_sort_traits, LA_traits_io and isscalar helpers, and LA_traits<C>, which picks an LA_traits_aux specialization for scalar or container element types. NOTE(review): this copy appears to have lost its line breaks and everything written between '<' and '>' (include targets, template parameter lists, template arguments such as the one after the bare 'complex' below). As laid out here, each '//' comments out the rest of its physical line, so it will not compile as-is. It should presumably be restored from the upstream header rather than patched in place. */ // //for autotools // #include "config.h" //////////////////////////////////////////////////////////////////////////// //LA traits classes and generally needed includes #ifndef _LA_TRAITS_INCL #define _LA_TRAITS_INCL #include #include #include #include /* NOTE(review): the four header names above are missing. Members defined later in this header use memcmp/memcpy/memset, read/write, ssize_t, std::abs and std::complex, so these were presumably <string.h>, <unistd.h>, <math.h> or <cmath>, and <complex>. Confirm against upstream. */ //using namespace std; /* Object-like macro: from here on, every bare identifier token 'complex', in this header and in any file that includes it, is rewritten to std::complex. Because it is a macro, a later 'std::complex' spelled out in full would become 'std::std::complex'. Any header that mentions 'complex' and is first included after this point is rewritten too, so <complex> must already be included above. */ #define complex std::complex #include "laerror.h" /* when NONCBLAS is defined use "noncblas.h", otherwise include "cblas.h" with C linkage */ #ifdef NONCBLAS #include "noncblas.h" #else extern "C" { #include "cblas.h" } #endif /* same pattern for "clapack.h". NOTE(review): the NONCLAPACK branch includes "noncblas.h", the same header as the NONCBLAS branch. Check whether a separate non-CLAPACK header was intended. */ #ifdef NONCLAPACK #include "noncblas.h" #else extern "C" { #include "clapack.h" } #endif namespace LA { /* declared here only; the definition is outside this header */ extern bool _LA_count_check; //forward declarations template class NRVec; template class NRMat; template class NRMat_from1; template class NRSMat; template class NRSMat_from1; template class SparseMat; template class SparseSMat; /* Two distinct empty placeholder types, used as typedef targets when a traits member has no meaningful type. Presumably two are needed so that NRMat_Noncomplex_type and NRSMat_Noncomplex_type stay different types in the placeholder case (e.g. for overloading). Confirm. */ typedef class {} Dummy_type; typedef class {} Dummy_type2; //for components of complex numbers // /* primary template: for a non-complex type every member typedef is a placeholder */ template struct LA_traits_complex { typedef Dummy_type Component_type; typedef Dummy_type NRVec_Noncomplex_type; typedef Dummy_type NRMat_Noncomplex_type; typedef Dummy_type2 NRSMat_Noncomplex_type; }; /* SPECIALIZE_COMPLEX(T) specializes LA_traits_complex for (presumably) complex<T>. Component_type is T, and the NRVec/NRMat/NRSMat _Noncomplex_type members name the corresponding containers. Their template arguments, presumably T, are lost in this copy. */ #define SPECIALIZE_COMPLEX(T) \ template<> \ struct LA_traits_complex > \ { \ typedef T Component_type; \ typedef NRVec NRVec_Noncomplex_type; \ typedef NRMat NRMat_Noncomplex_type; \ typedef NRSMat NRSMat_Noncomplex_type; \ }; SPECIALIZE_COMPLEX(double) SPECIALIZE_COMPLEX(complex) 
SPECIALIZE_COMPLEX(float) /* Remaining SPECIALIZE_COMPLEX instantiations; the bare 'complex' argument lost its template argument in this copy. NOTE(review): SPECIALIZE_COMPLEX is never #undef'd, unlike SCALAR and generate_traits, so it stays defined in including files. Also, actually instantiating std::complex for the integer types listed here is unspecified by the C++ standard. Naming them as template arguments is fine, but confirm that no code instantiates them. */ SPECIALIZE_COMPLEX(complex) SPECIALIZE_COMPLEX(char) SPECIALIZE_COMPLEX(unsigned char) SPECIALIZE_COMPLEX(short) SPECIALIZE_COMPLEX(unsigned short) SPECIALIZE_COMPLEX(int) SPECIALIZE_COMPLEX(unsigned int) SPECIALIZE_COMPLEX(long) SPECIALIZE_COMPLEX(unsigned long) SPECIALIZE_COMPLEX(long long) SPECIALIZE_COMPLEX(unsigned long long) //for general sortable classes /* LA_sort_traits: a declared-only primary template plus two specializations. compare(object,i,j) forwards to object.bigger(i,j) in the first and to object.smaller(i,j) in the second. The parameter and specialization argument lists are lost in this copy; presumably a template argument selects the sort direction. NOTE(review): object is taken by value, so every compare() call copies it. Consider const T& if bigger/smaller are const members. */ template struct LA_sort_traits; template struct LA_sort_traits { static inline bool compare(T object, I i, I j) {return object.bigger(i,j);}; }; template struct LA_sort_traits { static inline bool compare(T object, I i, I j) {return object.smaller(i,j);}; }; //we will need to treat char and unsigned char as numbers in << and >> I/O operators /* LA_traits_io<C>::IOtype is C itself by default. The two full specializations (arguments lost; presumably char and unsigned char, per the comment above) use int and unsigned int instead. */ template struct LA_traits_io { typedef C IOtype; }; template<> struct LA_traits_io { typedef int IOtype; }; template<> struct LA_traits_io { typedef unsigned int IOtype; }; //let's do some simple template metaprogramming and preprocessing //to keep the thing general and compact /* empty tag types: isscalar<X>::scalar_type is one of them, and it is passed as the second template argument of LA_traits_aux */ class scalar_false {}; class scalar_true {}; //default is non-scalar template class isscalar { public: typedef scalar_false scalar_type;}; //specializations /* SCALAR(X) sets scalar_type to scalar_true for X and, judging by the '> >' remnants, also for complex<X> and complex<complex<X> > */ #define SCALAR(X) \ template<>\ class isscalar {public: typedef scalar_true scalar_type;};\ template<>\ class isscalar > {public: typedef scalar_true scalar_type;};\ template<>\ class isscalar > > {public: typedef scalar_true scalar_type;};\ //declare what is scalar /* NOTE(review): long double does not appear in this list; confirm whether that is intentional */ SCALAR(char) SCALAR(short) SCALAR(int) SCALAR(long) SCALAR(long long) SCALAR(unsigned char) SCALAR(unsigned short) SCALAR(unsigned int) SCALAR(unsigned long) SCALAR(unsigned long long) SCALAR(float) SCALAR(double) SCALAR(void *) #undef SCALAR //declare this generically as traits for any unknown class /* primary LA_traits_aux: a type that matches no specialization gets only a placeholder normtype */ template struct LA_traits_aux { typedef Dummy_type normtype; }; //TRAITS SPECIALIZATIONS ////now declare the traits for scalars and for composed classes ////NOTE! 
methods in traits classes have to be declared static, ////since the class itself is never instantiated. ////for performance, it can be also inlined at the same time // //complex scalars /* LA_traits_aux for complex<C> scalars (argument list lost; presumably <complex<C>, scalar_true>). elementtype and producttype are complex<C>, and normtype is C. sqrabs returns real*real+imag*imag, and norm returns std::abs(x). gencmp returns memcmp(...) converted to bool, i.e. true when the two arrays DIFFER. get/put/multiget/multiput move raw bytes with read/write and call laerror on any short transfer or error, with no retry. NOTE(review): get/put compare sizeof with the read/write result without the (ssize_t) cast used in multiget/multiput; a -1 result still mismatches after the unsigned conversion, so errors are still caught, but the two styles are inconsistent. copy/clear use memcpy/memset, the dimensions/transp arguments are ignored, copyonwrite does nothing, and bigger/smaller always call laerror("complex comparison undefined") and return false. */ template struct LA_traits_aux, scalar_true> { typedef complex elementtype; typedef complex producttype; typedef C normtype; static inline C sqrabs(const complex x) { return x.real()*x.real()+x.imag()*x.imag();} static inline bool gencmp(const complex *x, const complex *y, int n) {return memcmp(x,y,n*sizeof(complex));} static bool bigger(const complex &x, const complex &y) {laerror("complex comparison undefined"); return false;} static bool smaller(const complex &x, const complex &y) {laerror("complex comparison undefined"); return false;} static inline normtype norm (const complex &x) {return std::abs(x);} static inline void axpy (complex &s, const complex &x, const complex &c) {s+=x*c;} static inline void get(int fd, complex &x, bool dimensions=0, bool transp=0) {if(sizeof(complex)!=read(fd,&x,sizeof(complex))) laerror("read error");} static inline void put(int fd, const complex &x, bool dimensions=0, bool transp=0) {if(sizeof(complex)!=write(fd,&x,sizeof(complex))) laerror("write error");} static void multiget(unsigned int n,int fd, complex *x, bool dimensions=0){if((ssize_t)(n*sizeof(complex))!=read(fd,x,n*sizeof(complex))) laerror("read error");} static void multiput(unsigned int n, int fd, const complex *x, bool dimensions=0) {if((ssize_t)(n*sizeof(complex))!=write(fd,x,n*sizeof(complex))) laerror("write error");} static void copy(complex *dest, complex *src, unsigned int n) {memcpy(dest,src,n*sizeof(complex));} static void clear(complex *dest, unsigned int n) {memset(dest,0,n*sizeof(complex));} static void copyonwrite(complex &x) {}; static void clearme(complex &x) {x=0;}; }; //non-complex scalars /* LA_traits_aux for real and integer scalars (argument list lost; presumably <C, scalar_true>). elementtype, producttype and normtype are all C, and sqrabs is x*x. gencmp is memcmp converted to bool, so again true when the arrays differ. bigger is x>y and smaller is presumably x<y. The text between smaller() and the final clearme (x=0) is lost in this copy. */ template struct LA_traits_aux { typedef C elementtype; typedef C producttype; typedef C normtype; static inline C sqrabs(const C x) { return x*x;} static inline bool gencmp(const C 
*x, const C *y, int n) {return memcmp(x,y,n*sizeof(C));} static inline bool bigger(const C &x, const C &y) {return x>y;} static inline bool smaller(const C &x, const C &y) {return x &x) {x=0;}; }; //non-scalars except smat template struct LA_traits; //forward declaration needed for template recursion /* generate_traits(X) defines LA_traits_aux for the container X (tag scalar_false). elementtype is C, producttype is X (presumably X<C>), and normtype comes from the element's LA_traits, which is why LA_traits is forward-declared above. norm, axpy, copyonwrite and clearme act on the container via x.norm(), s.axpy(c,x), x.copyonwrite() and x.clear(). NOTE(review): put and get are declared with a parameter of type C, the element type, yet they call x.put(...) and x.get(...). For a scalar C such as double those bodies would be ill-formed if instantiated. Confirm whether the container type was intended; smaller() is likewise declared on C. */ #define generate_traits(X) \ template \ struct LA_traits_aux, scalar_false> { \ typedef C elementtype; \ typedef X producttype; \ typedef typename LA_traits::normtype normtype; \ static bool gencmp(const C *x, const C *y, int n) {for(int i=0; iy;} \ static inline bool smaller(const C &x, const C &y) {return x &x) {return x.norm();} \ static inline void axpy (X&s, const X &x, const C c) {s.axpy(c,x);} \ static void put(int fd, const C &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions,transp);} \ static void get(int fd, C &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions,transp);} \ static void multiput(unsigned int n,int fd, const C *x, bool dimensions=1) {for(unsigned int i=0; i &x) {x.copyonwrite();}\ static void clearme(X &x) {x.clear();}\ }; //non-scalar types defined in this library generate_traits(NRMat) generate_traits(NRMat_from1) generate_traits(NRVec) generate_traits(SparseMat) generate_traits(SparseSMat) //product leading to non-symmetric result not implemented #undef generate_traits //smat /* generate_traits_smat(X) is the same scheme as generate_traits, with two differences. producttype is NRMat rather than X. put/get forward only fd and dimensions to the object; the transp argument is accepted and ignored. */ #define generate_traits_smat(X) \ template \ struct LA_traits_aux, scalar_false> { \ typedef C elementtype; \ typedef NRMat producttype; \ typedef typename LA_traits::normtype normtype; \ static bool gencmp(const C *x, const C *y, int n) {for(int i=0; iy;} \ static inline bool smaller(const C &x, const C &y) {return x &x) {return x.norm();} \ static inline void axpy (X&s, const X &x, const C c) {s.axpy(c,x);} \ static void put(int fd, const C &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions);} \ static void get(int fd, C &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions);} \ static void multiput(unsigned int n,int fd, const C *x, bool 
dimensions=1) {for(unsigned int i=0; i &x) {x.copyonwrite();} \ static void clearme(X &x) {x.clear();} \ }; /* Instantiate for the symmetric containers. NOTE(review): unlike generate_traits, generate_traits_smat is not #undef'd afterwards, so the macro stays defined in every file that includes this header. */ generate_traits_smat(NRSMat) generate_traits_smat(NRSMat_from1) //the final traits class /* LA_traits<C> (parameter list lost) inherits the LA_traits_aux specialization selected by isscalar<C>::scalar_type: scalar_true for the types declared scalar, scalar_false otherwise. */ template struct LA_traits : LA_traits_aux::scalar_type> {}; }//namespace #endif