2008-02-26 14:55:23 +01:00
/*
LA: linear algebra C++ interface library
Copyright (C) 2008 Jiri Pittner <jiri.pittner@jh-inst.cas.cz> or <jiri@pittnerovi.com>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/
2009-10-08 16:01:15 +02:00
//
//for autotools
//
2010-01-07 17:10:12 +01:00
//#include "config.h" //this would force the user of the library to have config.h
2009-10-08 16:01:15 +02:00
2004-03-17 04:07:21 +01:00
////////////////////////////////////////////////////////////////////////////
2005-02-18 23:08:15 +01:00
//LA traits classes and generally needed includes
2004-03-17 04:07:21 +01:00
# ifndef _LA_TRAITS_INCL
# define _LA_TRAITS_INCL
2005-02-18 23:08:15 +01:00
2009-10-08 16:01:15 +02:00
2005-02-18 23:08:15 +01:00
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <iostream>
2010-12-23 12:30:00 +01:00
# include <fstream>
2010-09-08 18:27:58 +02:00
# include <limits>
2005-02-18 23:08:15 +01:00
# include <complex>
2009-11-12 22:01:19 +01:00
//using namespace std;
# define complex std::complex
2005-02-18 23:08:15 +01:00
# include "laerror.h"
2010-06-25 17:28:19 +02:00
# include "cuda_la.h"
2005-09-06 17:55:07 +02:00
# ifdef NONCBLAS
# include "noncblas.h"
# else
2005-02-18 23:08:15 +01:00
extern " C " {
# include "cblas.h"
}
2005-09-06 17:55:07 +02:00
# endif
2005-02-18 23:08:15 +01:00
2009-10-08 16:01:15 +02:00
# ifdef NONCLAPACK
# include "noncblas.h"
# else
extern " C " {
# include "clapack.h"
}
# endif
2009-11-12 22:01:19 +01:00
namespace LA {
//global flag, only declared here (its definition lives in some other translation unit)
//NOTE(review): the name suggests it enables extra count/consistency checks - confirm at the definition;
//also, identifiers beginning with '_' followed by an uppercase letter are reserved in C++
extern bool _LA_count_check;

//forward declarations of the library's container templates; this header names them
//(in LA_traits_complex and in the container traits macros) without including their definitions
template<typename C> class NRVec;
template<typename C> class NRMat;
template<typename C> class NRMat_from1;
template<typename C> class NRSMat;
template<typename C> class NRSMat_from1;
template<typename C> class SparseMat;
template<typename C> class SparseSMat;
template<typename C> class CSRMat;
2005-02-14 01:10:07 +01:00
2009-10-08 16:01:15 +02:00
//placeholder types used wherever a trait has no meaningful answer;
//the NRSMat slot uses Dummy_type2, distinct from the Dummy_type used by the NRVec/NRMat slots
typedef class {} Dummy_type;
typedef class {} Dummy_type2;

//LA_traits_complex<T>: when T is complex<X>, exposes the component type X and the
//NRVec/NRMat/NRSMat containers over X; for any other T all members are placeholders
template<typename C>
struct LA_traits_complex
	{
	typedef Dummy_type2 NRSMat_Noncomplex_type;
	typedef Dummy_type  NRMat_Noncomplex_type;
	typedef Dummy_type  NRVec_Noncomplex_type;
	typedef Dummy_type  Component_type;
	};

//full specialization of LA_traits_complex for complex<T>
#define SPECIALIZE_COMPLEX(T) \
template<> struct LA_traits_complex<complex<T> > \
	{ \
	typedef NRSMat<T> NRSMat_Noncomplex_type; \
	typedef NRMat<T>  NRMat_Noncomplex_type; \
	typedef NRVec<T>  NRVec_Noncomplex_type; \
	typedef T         Component_type; \
	};

//floating-point components, including the doubly-complex case
SPECIALIZE_COMPLEX(double)
SPECIALIZE_COMPLEX(float)
SPECIALIZE_COMPLEX(complex<double>)
SPECIALIZE_COMPLEX(complex<float>)
//integral components
SPECIALIZE_COMPLEX(char)
SPECIALIZE_COMPLEX(short)
SPECIALIZE_COMPLEX(int)
SPECIALIZE_COMPLEX(long)
SPECIALIZE_COMPLEX(long long)
SPECIALIZE_COMPLEX(unsigned char)
SPECIALIZE_COMPLEX(unsigned short)
SPECIALIZE_COMPLEX(unsigned int)
SPECIALIZE_COMPLEX(unsigned long)
SPECIALIZE_COMPLEX(unsigned long long)
2009-10-08 16:01:15 +02:00
2006-04-01 14:58:57 +02:00
//LA_sort_traits<T,I,type>::compare(object,i,j) delegates an index comparison to the
//sortable object itself: type 0 asks object.bigger(i,j), type 1 asks object.smaller(i,j);
//no other value of type is defined
template<typename T, typename I, int type> struct LA_sort_traits;

template<typename T, typename I>
struct LA_sort_traits<T,I,1>
{
static inline bool compare(T object, I i, I j)
	{
	const bool lesser = object.smaller(i,j);
	return lesser;
	}
};

template<typename T, typename I>
struct LA_sort_traits<T,I,0>
{
static inline bool compare(T object, I i, I j)
	{
	const bool greater = object.bigger(i,j);
	return greater;
	}
};
2006-04-01 06:48:01 +02:00
//I/O type selector for the << and >> operators: char and unsigned char are widened to
//int / unsigned int so they are streamed as numbers rather than characters;
//every other type is streamed as itself
template<typename C> struct LA_traits_io                { typedef C            IOtype; };
template<>           struct LA_traits_io<unsigned char> { typedef unsigned int IOtype; };
template<>           struct LA_traits_io<char>          { typedef int          IOtype; };
2005-02-14 01:10:07 +01:00
//let's do some simple template metaprogramming and preprocessing
//to keep the thing general and compact
2008-11-26 14:30:30 +01:00
class scalar_false { } ;
class scalar_true { } ;
2005-02-14 01:10:07 +01:00
//default is non-scalar
template < typename C >
2005-12-08 13:06:23 +01:00
class isscalar { public : typedef scalar_false scalar_type ; } ;
2005-02-14 01:10:07 +01:00
//specializations
# define SCALAR(X) \
2005-11-20 14:46:00 +01:00
template < > \
2009-10-08 16:01:15 +02:00
class isscalar < X > { public : typedef scalar_true scalar_type ; } ; \
template < > \
class isscalar < complex < X > > { public : typedef scalar_true scalar_type ; } ; \
template < > \
class isscalar < complex < complex < X > > > { public : typedef scalar_true scalar_type ; } ; \
2005-02-14 01:10:07 +01:00
//declare what is scalar
SCALAR ( char )
SCALAR ( short )
SCALAR ( int )
SCALAR ( long )
SCALAR ( long long )
SCALAR ( unsigned char )
SCALAR ( unsigned short )
SCALAR ( unsigned int )
SCALAR ( unsigned long )
SCALAR ( unsigned long long )
SCALAR ( float )
SCALAR ( double )
SCALAR ( void * )
# undef SCALAR
2009-10-08 16:01:15 +02:00
//fallback traits for any (type, scalar-tag) pair that none of the specializations match:
//only normtype exists, and it is a placeholder
template<typename C, typename Scalar>
struct LA_traits_aux
	{
	typedef Dummy_type normtype;
	};
2005-02-14 01:10:07 +01:00
2005-09-06 17:55:07 +02:00
//TRAITS SPECIALIZATIONS
2009-10-08 16:01:15 +02:00
////now declare the traits for scalars and for composed classes
////NOTE! methods in traits classes have to be declared static,
////since the class itself is never instantiated.
////for performance, it can be also inlined at the same time
//
2005-09-06 17:55:07 +02:00
2005-02-14 01:10:07 +01:00
//complex scalars
//traits for complex<C>, where C is one of the real scalar types; element I/O is raw binary on a file descriptor
template<typename C>
struct LA_traits_aux<complex<C>, scalar_true> {
typedef complex<C> elementtype;
typedef complex<C> producttype;
typedef C normtype;
typedef C realtype;
typedef complex<C> complextype;

//squared modulus |x|^2, computed without a square root
static inline C sqrabs(const complex<C> x) {return x.real()*x.real()+x.imag()*x.imag();}
//bytewise (memcmp) comparison of n elements; returns true when the arrays DIFFER
static inline bool gencmp(const complex<C> *x, const complex<C> *y, int n) {return memcmp(x,y,n*sizeof(complex<C>));}
//complex numbers have no ordering: both comparisons report an error through laerror
static bool bigger(const complex<C> &x, const complex<C> &y) {laerror("complex comparison undefined"); return false;}
static bool smaller(const complex<C> &x, const complex<C> &y) {laerror("complex comparison undefined"); return false;}
static inline normtype norm(const complex<C> &x) {return std::abs(x);}
//s += x*c
static inline void axpy(complex<C> &s, const complex<C> &x, const complex<C> &c) {s+=x*c;}
//single-element raw binary I/O; dimensions and transp are accepted for interface uniformity but unused
static inline void get(int fd, complex<C> &x, bool dimensions=0, bool transp=0) {if((ssize_t)sizeof(complex<C>) != read(fd,&x,sizeof(complex<C>))) laerror("read error");}
static inline void put(int fd, const complex<C> &x, bool dimensions=0, bool transp=0) {if((ssize_t)sizeof(complex<C>) != write(fd,&x,sizeof(complex<C>))) laerror("write error");}
//raw binary I/O of n consecutive elements.
//read()/write() may legitimately transfer fewer bytes than requested (on Linux a single call
//moves at most 0x7ffff000 bytes), so keep looping until the whole buffer is done;
//a return of 0 (EOF) or -1 (error) before that is reported through laerror
static void multiget(size_t n, int fd, complex<C> *x, bool dimensions=0)
	{
	char *p=reinterpret_cast<char *>(x);
	size_t left=n*sizeof(complex<C>);
	while(left>0)
		{
		ssize_t r=read(fd,p,left);
		if(r<=0) {std::cout << "read returned " << r << std::endl; laerror("read error"); return;}
		p+=r;
		left-=(size_t)r;
		}
	}
static void multiput(size_t n, int fd, const complex<C> *x, bool dimensions=0)
	{
	const char *p=reinterpret_cast<const char *>(x);
	size_t left=n*sizeof(complex<C>);
	while(left>0)
		{
		ssize_t r=write(fd,p,left);
		if(r<=0) {std::cout << "write returned " << r << std::endl; laerror("write error"); return;}
		p+=r;
		left-=(size_t)r;
		}
	}
//bitwise copy / zero fill of n elements
static void copy(complex<C> *dest, complex<C> *src, unsigned int n) {memcpy(dest,src,n*sizeof(complex<C>));}
static void clear(complex<C> *dest, unsigned int n) {memset(dest,0,n*sizeof(complex<C>));}
//no-ops for scalars; they exist so generic container code can call them uniformly
static void copyonwrite(complex<C> &x) {}
static void clearme(complex<C> &x) {x=0;}
static void deallocate(complex<C> &x) {}
static inline complex<C> conjugate(const complex<C> &x) {return complex<C>(x.real(),-x.imag());}
static inline C realpart(const complex<C> &x) {return x.real();}
static inline C imagpart(const complex<C> &x) {return x.imag();}
};
//non-complex scalars
template < typename C >
struct LA_traits_aux < C , scalar_true > {
2004-03-17 04:07:21 +01:00
typedef C elementtype ;
typedef C producttype ;
2005-02-14 01:10:07 +01:00
typedef C normtype ;
2010-06-25 17:28:19 +02:00
typedef C realtype ;
typedef complex < C > complextype ;
2009-11-12 22:01:19 +01:00
static inline C sqrabs ( const C x ) { return x * x ; }
2005-09-06 17:55:07 +02:00
static inline bool gencmp ( const C * x , const C * y , int n ) { return memcmp ( x , y , n * sizeof ( C ) ) ; }
static inline bool bigger ( const C & x , const C & y ) { return x > y ; }
static inline bool smaller ( const C & x , const C & y ) { return x < y ; }
2009-11-12 22:01:19 +01:00
static inline normtype norm ( const C & x ) { return std : : abs ( x ) ; }
2005-09-06 17:55:07 +02:00
static inline void axpy ( C & s , const C & x , const C & c ) { s + = x * c ; }
2005-09-11 22:04:24 +02:00
static inline void put ( int fd , const C & x , bool dimensions = 0 , bool transp = 0 ) { if ( sizeof ( C ) ! = write ( fd , & x , sizeof ( C ) ) ) laerror ( " write error " ) ; }
static inline void get ( int fd , C & x , bool dimensions = 0 , bool transp = 0 ) { if ( sizeof ( C ) ! = read ( fd , & x , sizeof ( C ) ) ) laerror ( " read error " ) ; }
2010-12-23 12:30:00 +01:00
static void multiget ( size_t n , int fd , C * x , bool dimensions = 0 ) { ssize_t r = read ( fd , x , n * sizeof ( C ) ) ; if ( ( ssize_t ) ( n * sizeof ( C ) ) ! = r ) { std : : cout < < " read returned " < < r < < std : : endl ; laerror ( " read error " ) ; } }
static void multiput ( size_t n , int fd , const C * x , bool dimensions = 0 ) { ssize_t r = write ( fd , x , n * sizeof ( C ) ) ; if ( ( ssize_t ) ( n * sizeof ( C ) ) ! = r ) { std : : cout < < " write returned " < < r < < std : : endl ; laerror ( " write error " ) ; } }
2006-04-01 06:48:01 +02:00
static void copy ( C * dest , C * src , unsigned int n ) { memcpy ( dest , src , n * sizeof ( C ) ) ; }
2006-04-09 23:07:54 +02:00
static void clear ( C * dest , unsigned int n ) { memset ( dest , 0 , n * sizeof ( C ) ) ; }
2007-11-29 14:52:31 +01:00
static void copyonwrite ( C & x ) { } ;
2010-12-23 12:30:00 +01:00
static void clearme ( C & x ) { x = 0 ; } ;
static void deallocate ( C & x ) { } ;
2010-01-11 11:12:28 +01:00
static inline C conjugate ( const C & x ) { return x ; } ;
static inline C realpart ( const C & x ) { return x ; }
static inline C imagpart ( const C & x ) { return 0 ; }
2004-03-17 04:07:21 +01:00
} ;
2005-09-06 17:55:07 +02:00
//non-scalars except smat
2005-02-14 01:10:07 +01:00
template < typename C >
struct LA_traits ; //forward declaration needed for template recursion
# define generate_traits(X) \
template < typename C > \
struct LA_traits_aux < X < C > , scalar_false > { \
typedef C elementtype ; \
typedef X < C > producttype ; \
typedef typename LA_traits < C > : : normtype normtype ; \
2010-06-25 17:28:19 +02:00
typedef X < typename LA_traits < C > : : realtype > realtype ; \
typedef X < typename LA_traits < C > : : complextype > complextype ; \
2005-09-06 17:55:07 +02:00
static bool gencmp ( const C * x , const C * y , int n ) { for ( int i = 0 ; i < n ; + + i ) if ( x [ i ] ! = y [ i ] ) return true ; return false ; } \
static inline bool bigger ( const C & x , const C & y ) { return x > y ; } \
static inline bool smaller ( const C & x , const C & y ) { return x < y ; } \
static inline normtype norm ( const X < C > & x ) { return x . norm ( ) ; } \
static inline void axpy ( X < C > & s , const X < C > & x , const C c ) { s . axpy ( c , x ) ; } \
2010-01-07 17:10:12 +01:00
static void put ( int fd , const X < C > & x , bool dimensions = 1 , bool transp = 0 ) { x . put ( fd , dimensions , transp ) ; } \
static void get ( int fd , X < C > & x , bool dimensions = 1 , bool transp = 0 ) { x . get ( fd , dimensions , transp ) ; } \
2010-12-23 12:30:00 +01:00
static void multiput ( size_t n , int fd , const X < C > * x , bool dimensions = 1 ) { for ( size_t i = 0 ; i < n ; + + i ) x [ i ] . put ( fd , dimensions ) ; } \
static void multiget ( size_t n , int fd , X < C > * x , bool dimensions = 1 ) { for ( size_t i = 0 ; i < n ; + + i ) x [ i ] . get ( fd , dimensions ) ; } \
2006-04-01 06:48:01 +02:00
static void copy ( C * dest , C * src , unsigned int n ) { for ( unsigned int i = 0 ; i < n ; + + i ) dest [ i ] = src [ i ] ; } \
2006-04-10 18:08:42 +02:00
static void clear ( C * dest , unsigned int n ) { for ( unsigned int i = 0 ; i < n ; + + i ) dest [ i ] . clear ( ) ; } \
2007-11-29 14:52:31 +01:00
static void copyonwrite ( X < C > & x ) { x . copyonwrite ( ) ; } \
2009-11-12 22:01:19 +01:00
static void clearme ( X < C > & x ) { x . clear ( ) ; } \
2010-12-23 12:30:00 +01:00
static void deallocate ( X < C > & x ) { x . dealloc ( ) ; } \
2004-03-17 04:07:21 +01:00
} ;
2005-02-14 01:10:07 +01:00
//non-scalar types defined in this library
generate_traits ( NRMat )
2006-09-18 23:46:45 +02:00
generate_traits ( NRMat_from1 )
2005-02-14 01:10:07 +01:00
generate_traits ( NRVec )
generate_traits ( SparseMat )
2009-11-12 22:01:19 +01:00
generate_traits ( SparseSMat ) //product leading to non-symmetric result not implemented
2010-12-23 12:30:00 +01:00
generate_traits ( CSRMat )
2005-02-14 01:10:07 +01:00
# undef generate_traits
2005-09-06 17:55:07 +02:00
//smat
2006-09-18 23:46:45 +02:00
# define generate_traits_smat(X) \
template < typename C > \
struct LA_traits_aux < X < C > , scalar_false > { \
typedef C elementtype ; \
typedef NRMat < C > producttype ; \
typedef typename LA_traits < C > : : normtype normtype ; \
2010-06-25 17:28:19 +02:00
typedef X < typename LA_traits < C > : : realtype > realtype ; \
typedef X < typename LA_traits < C > : : complextype > complextype ; \
2006-09-18 23:46:45 +02:00
static bool gencmp ( const C * x , const C * y , int n ) { for ( int i = 0 ; i < n ; + + i ) if ( x [ i ] ! = y [ i ] ) return true ; return false ; } \
static inline bool bigger ( const C & x , const C & y ) { return x > y ; } \
static inline bool smaller ( const C & x , const C & y ) { return x < y ; } \
static inline normtype norm ( const X < C > & x ) { return x . norm ( ) ; } \
static inline void axpy ( X < C > & s , const X < C > & x , const C c ) { s . axpy ( c , x ) ; } \
2010-01-07 17:10:12 +01:00
static void put ( int fd , const X < C > & x , bool dimensions = 1 , bool transp = 0 ) { x . put ( fd , dimensions ) ; } \
static void get ( int fd , X < C > & x , bool dimensions = 1 , bool transp = 0 ) { x . get ( fd , dimensions ) ; } \
2010-12-23 12:30:00 +01:00
static void multiput ( size_t n , int fd , const X < C > * x , bool dimensions = 1 ) { for ( size_t i = 0 ; i < n ; + + i ) x [ i ] . put ( fd , dimensions ) ; } \
static void multiget ( size_t n , int fd , X < C > * x , bool dimensions = 1 ) { for ( size_t i = 0 ; i < n ; + + i ) x [ i ] . get ( fd , dimensions ) ; } \
2006-09-18 23:46:45 +02:00
static void copy ( C * dest , C * src , unsigned int n ) { for ( unsigned int i = 0 ; i < n ; + + i ) dest [ i ] = src [ i ] ; } \
static void clear ( C * dest , unsigned int n ) { for ( unsigned int i = 0 ; i < n ; + + i ) dest [ i ] . clear ( ) ; } \
2007-11-29 14:52:31 +01:00
static void copyonwrite ( X < C > & x ) { x . copyonwrite ( ) ; } \
2009-11-12 22:01:19 +01:00
static void clearme ( X < C > & x ) { x . clear ( ) ; } \
2010-12-23 12:30:00 +01:00
static void deallocate ( X < C > & x ) { x . dealloc ( ) ; } \
2004-03-17 04:07:21 +01:00
} ;
2006-09-18 23:46:45 +02:00
generate_traits_smat ( NRSMat )
generate_traits_smat ( NRSMat_from1 )
2004-03-17 04:07:21 +01:00
2009-10-08 16:01:15 +02:00
2005-02-14 01:10:07 +01:00
//the final traits class:
//LA_traits<C> selects the LA_traits_aux specialization matching C and its isscalar<> tag,
//and inherits all of that specialization's typedefs and static helpers
template<typename C>
struct LA_traits
	: public LA_traits_aux<C, typename isscalar<C>::scalar_type>
{
};
2004-03-17 04:07:21 +01:00
2009-11-12 22:01:19 +01:00
} //namespace
2004-03-17 04:07:21 +01:00
# endif