LA_library/la_traits.h

409 lines
14 KiB
C
Raw Normal View History

2008-02-26 14:55:23 +01:00
/*
LA: linear algebra C++ interface library
Copyright (C) 2008 Jiri Pittner <jiri.pittner@jh-inst.cas.cz> or <jiri@pittnerovi.com>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
2009-10-08 16:01:15 +02:00
//
//for autotools
//
2010-01-07 17:10:12 +01:00
//#include "config.h" //this would force the user of the library to have config.h
2009-10-08 16:01:15 +02:00
2004-03-17 04:07:21 +01:00
////////////////////////////////////////////////////////////////////////////
2005-02-18 23:08:15 +01:00
//LA traits classes and generally needed includes
2004-03-17 04:07:21 +01:00
#ifndef _LA_TRAITS_INCL
#define _LA_TRAITS_INCL
2005-02-18 23:08:15 +01:00
2009-10-08 16:01:15 +02:00
2005-02-18 23:08:15 +01:00
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <iostream>
#include <fstream>
#include <limits>
#include <complex>
2009-11-12 22:01:19 +01:00
//using namespace std;
2005-02-18 23:08:15 +01:00
#include "laerror.h"
2010-06-25 17:28:19 +02:00
#include "cuda_la.h"
2005-09-06 17:55:07 +02:00
#ifdef NONCBLAS
#include "noncblas.h"
#else
2005-02-18 23:08:15 +01:00
extern "C" {
2019-11-13 23:22:25 +01:00
#ifdef HAS_MKL
#include "mkl_cblas.h"
#else
2005-02-18 23:08:15 +01:00
#include "cblas.h"
2019-11-13 23:22:25 +01:00
#endif
2005-02-18 23:08:15 +01:00
}
2005-09-06 17:55:07 +02:00
#endif
2005-02-18 23:08:15 +01:00
2009-10-08 16:01:15 +02:00
#ifdef NONCLAPACK
#include "noncblas.h"
#else
extern "C" {
2013-11-04 15:56:39 +01:00
#include "atlas/clapack.h"
2009-10-08 16:01:15 +02:00
}
#endif
2009-11-12 22:01:19 +01:00
namespace LA {
//global library flag; only declared here, its definition lives in another translation unit
//NOTE(review): the name suggests it switches on checking of reference counts - confirm at the definition
extern bool _LA_count_check;
2005-02-18 23:08:15 +01:00
2005-02-14 01:10:07 +01:00
//forward declarations of the container templates referred to by the traits below;
//they are only named here (pointers, typedefs), their definitions come from other headers
template<class C> class NRVec;
template<class C> class NRMat;
template<class C> class NRMat_from1;
template<class C> class NRSMat;
template<class C> class NRSMat_from1;
template<class C> class SparseMat;
template<class C> class SparseSMat;
template<class C> class CSRMat;
2005-02-14 01:10:07 +01:00
2021-04-21 15:04:37 +02:00
//trick to allow real and imag part of complex as l-values:
//the standard specifies (for float, double and long double) that std::complex<T> may be
//accessed as an array of two T, the real part first and the imaginary part second
template<typename T>
T &real(std::complex<T> &c)
{
	T (&components)[2] = reinterpret_cast<T (&)[2]>(c);
	return components[0];
}
template<typename T>
T &imag(std::complex<T> &c)
{
	T (&components)[2] = reinterpret_cast<T (&)[2]>(c);
	return components[1];
}
//
//placeholder types used by traits that do not apply to a given type;
//two distinct classes are declared and NRSMat_Noncomplex_type uses the second one
//(presumably to keep it a type distinct from the NRMat placeholder)
class Dummy_type {};
class Dummy_type2 {};
//for components of complex numbers:
//the generic version (C not a complex type) maps every member type to a placeholder
template<typename C>
struct LA_traits_complex
{
typedef Dummy_type2 NRSMat_Noncomplex_type;
typedef Dummy_type NRMat_Noncomplex_type;
typedef Dummy_type NRVec_Noncomplex_type;
typedef Dummy_type Component_type;
};
#define SPECIALIZE_COMPLEX(T) \
template<> \
2021-04-21 15:04:37 +02:00
struct LA_traits_complex<std::complex<T> > \
2009-10-08 16:01:15 +02:00
{ \
typedef T Component_type; \
typedef NRVec<T> NRVec_Noncomplex_type; \
typedef NRMat<T> NRMat_Noncomplex_type; \
typedef NRSMat<T> NRSMat_Noncomplex_type; \
};
SPECIALIZE_COMPLEX(double)
2021-04-21 15:04:37 +02:00
SPECIALIZE_COMPLEX(std::complex<double>)
2009-10-08 16:01:15 +02:00
SPECIALIZE_COMPLEX(float)
2021-04-21 15:04:37 +02:00
SPECIALIZE_COMPLEX(std::complex<float>)
2009-10-08 16:01:15 +02:00
SPECIALIZE_COMPLEX(char)
SPECIALIZE_COMPLEX(unsigned char)
SPECIALIZE_COMPLEX(short)
2009-11-12 22:01:19 +01:00
SPECIALIZE_COMPLEX(unsigned short)
2009-10-08 16:01:15 +02:00
SPECIALIZE_COMPLEX(int)
SPECIALIZE_COMPLEX(unsigned int)
2009-11-12 22:01:19 +01:00
SPECIALIZE_COMPLEX(long)
2009-10-08 16:01:15 +02:00
SPECIALIZE_COMPLEX(unsigned long)
2009-11-12 22:01:19 +01:00
SPECIALIZE_COMPLEX(long long)
SPECIALIZE_COMPLEX(unsigned long long)
2009-10-08 16:01:15 +02:00
2006-04-01 14:58:57 +02:00
//for general sortable classes:
//LA_sort_traits<T,I,0>::compare(object,i,j) returns object.bigger(i,j),
//LA_sort_traits<T,I,1>::compare(object,i,j) returns object.smaller(i,j);
//the primary template is only declared, so any other value of type fails to compile
template<typename T, typename I, int type> struct LA_sort_traits;

template<typename T, typename I>
struct LA_sort_traits<T,I,1>
{
	//note: object is received by value, i.e. it is copied on every call
	static inline bool compare(T object, I i, I j)
	{
		return object.smaller(i,j);
	}
};

template<typename T, typename I>
struct LA_sort_traits<T,I,0>
{
	//note: object is received by value, i.e. it is copied on every call
	static inline bool compare(T object, I i, I j)
	{
		return object.bigger(i,j);
	}
};
2006-04-01 06:48:01 +02:00
//we will need to treat char and unsigned char as numbers in << and >> I/O operators;
//signed char is a third, distinct character type that streams would also print as a character,
//so it is mapped to int as well.
//LA_traits_io<C>::IOtype is the type a value of type C is converted to for stream I/O.
template<typename C>
struct LA_traits_io
{
typedef C IOtype;
};
template<>
struct LA_traits_io<char>
{
typedef int IOtype;
};
template<>
struct LA_traits_io<signed char>
{
typedef int IOtype;
};
template<>
struct LA_traits_io<unsigned char>
{
typedef unsigned int IOtype;
};
2005-02-14 01:10:07 +01:00
//let's do some simple template metaprogramming and preprocessing
//to keep the thing general and compact
//tag types telling whether a type is an elementary (possibly complex) number
class scalar_false {};
class scalar_true {};
//isscalar<C>::scalar_type is scalar_true for the types listed below, for std::complex of them
//and for std::complex<std::complex<> > of them; every other type defaults to scalar_false
template<typename C>
class isscalar {
public:
	typedef scalar_false scalar_type;
};
//specializations
#define SCALAR(X) \
template<> class isscalar<X> { \
public: \
	typedef scalar_true scalar_type; \
}; \
template<> class isscalar<std::complex<X> > { \
public: \
	typedef scalar_true scalar_type; \
}; \
template<> class isscalar<std::complex<std::complex<X> > > { \
public: \
	typedef scalar_true scalar_type; \
};
//declare what is scalar: integer types, floating point types and void pointer
SCALAR(char)
SCALAR(unsigned char)
SCALAR(short)
SCALAR(unsigned short)
SCALAR(int)
SCALAR(unsigned int)
SCALAR(long)
SCALAR(unsigned long)
SCALAR(long long)
SCALAR(unsigned long long)
SCALAR(float)
SCALAR(double)
SCALAR(void *)
#undef SCALAR
2009-10-08 16:01:15 +02:00
//generic traits for any class not covered by a specialization below:
//only a placeholder normtype is provided
template<typename C, typename Scalar>
struct LA_traits_aux { typedef Dummy_type normtype; };
2005-02-14 01:10:07 +01:00
2005-09-06 17:55:07 +02:00
//TRAITS SPECIALIZATIONS
2009-10-08 16:01:15 +02:00
////now declare the traits for scalars and for composed classes
////NOTE! methods in traits classes have to be declared static,
////since the class itself is never instantiated.
////for performance, it can be also inlined at the same time
//
2005-09-06 17:55:07 +02:00
2005-02-14 01:10:07 +01:00
//complex scalars
template<typename C>
2021-04-21 15:04:37 +02:00
struct LA_traits_aux<std::complex<C>, scalar_true> {
typedef std::complex<C> elementtype;
typedef std::complex<C> producttype;
2005-02-14 01:10:07 +01:00
typedef C normtype;
2010-06-25 17:28:19 +02:00
typedef C realtype;
2021-04-21 15:04:37 +02:00
typedef std::complex<C> complextype;
static inline C sqrabs(const std::complex<C> x) { return x.real()*x.real()+x.imag()*x.imag();}
static inline bool gencmp(const std::complex<C> *x, const std::complex<C> *y, size_t n) {return memcmp(x,y,n*sizeof(std::complex<C>));}
static bool bigger(const std::complex<C> &x, const std::complex<C> &y) {laerror("std::complex comparison undefined"); return false;}
static bool smaller(const std::complex<C> &x, const std::complex<C> &y) {laerror("std::complex comparison undefined"); return false;}
static inline normtype norm (const std::complex<C> &x) {return std::abs(x);}
static inline void axpy (std::complex<C> &s, const std::complex<C> &x, const std::complex<C> &c) {s+=x*c;}
static inline void get(int fd, std::complex<C> &x, bool dimensions=0, bool transp=0) {if(sizeof(std::complex<C>)!=read(fd,&x,sizeof(std::complex<C>))) laerror("read error");}
static inline void put(int fd, const std::complex<C> &x, bool dimensions=0, bool transp=0) {if(sizeof(std::complex<C>)!=write(fd,&x,sizeof(std::complex<C>))) laerror("write error");}
static void multiget(size_t n,int fd, std::complex<C> *x, bool dimensions=0)
2011-01-18 15:37:05 +01:00
{
size_t total=0;
2021-04-21 15:04:37 +02:00
size_t system_limit = (1L<<30)/sizeof(std::complex<C>); //do not expect too much from the system and read at most 1GB at once
2011-01-18 15:37:05 +01:00
ssize_t r;
2013-11-04 15:56:39 +01:00
size_t nn;
2011-01-18 15:37:05 +01:00
do{
2021-04-21 15:04:37 +02:00
r=read(fd,x+total,nn=(n-total > system_limit ? system_limit : n-total)*sizeof(std::complex<C>));
2013-11-04 15:56:39 +01:00
if(r<0 || r==0 && nn!=0 ) {std::cout<<"read returned "<<r<<" perror "<<strerror(errno) <<std::endl; laerror("read error");}
2021-04-21 15:04:37 +02:00
else total += r/sizeof(std::complex<C>);
if(r%sizeof(std::complex<C>)) laerror("read error 2");
2011-01-18 15:37:05 +01:00
}
while(total < n);
}
2021-04-21 15:04:37 +02:00
static void multiput(size_t n, int fd, const std::complex<C> *x, bool dimensions=0)
2011-01-18 15:37:05 +01:00
{
size_t total=0;
2021-04-21 15:04:37 +02:00
size_t system_limit = (1L<<30)/sizeof(std::complex<C>); //do not expect too much from the system and write at most 1GB at once
2011-01-18 15:37:05 +01:00
ssize_t r;
2013-11-04 15:56:39 +01:00
size_t nn;
2011-01-18 15:37:05 +01:00
do{
2021-04-21 15:04:37 +02:00
r=write(fd,x+total,nn=(n-total > system_limit ? system_limit : n-total)*sizeof(std::complex<C>));
2013-11-04 15:56:39 +01:00
if(r<0 || r==0 && nn!=0 ) {std::cout<<"write returned "<<r<<" perror "<<strerror(errno) <<std::endl; laerror("write error");}
2021-04-21 15:04:37 +02:00
else total += r/sizeof(std::complex<C>);
if(r%sizeof(std::complex<C>)) laerror("write error 2");
2011-01-18 15:37:05 +01:00
}
while(total < n);
}
2021-04-21 15:04:37 +02:00
static void copy(std::complex<C> *dest, std::complex<C> *src, size_t n) {memcpy(dest,src,n*sizeof(std::complex<C>));}
static void clear(std::complex<C> *dest, size_t n) {memset(dest,0,n*sizeof(std::complex<C>));}
static void copyonwrite(std::complex<C> &x) {};
static void clearme(std::complex<C> &x) {x=0;};
static void deallocate(std::complex<C> &x) {};
static inline std::complex<C> conjugate(const std::complex<C> &x) {return std::complex<C>(x.real(),-x.imag());};
static inline C realpart(const std::complex<C> &x) {return x.real();}
static inline C imagpart(const std::complex<C> &x) {return x.imag();}
2005-02-14 01:10:07 +01:00
};
2011-01-18 15:37:05 +01:00
2005-02-14 01:10:07 +01:00
//non-complex scalars
template<typename C>
struct LA_traits_aux<C, scalar_true> {
2004-03-17 04:07:21 +01:00
typedef C elementtype;
typedef C producttype;
2005-02-14 01:10:07 +01:00
typedef C normtype;
2010-06-25 17:28:19 +02:00
typedef C realtype;
2021-04-21 15:04:37 +02:00
typedef std::complex<C> complextype;
2009-11-12 22:01:19 +01:00
static inline C sqrabs(const C x) { return x*x;}
2013-11-04 15:56:39 +01:00
static inline bool gencmp(const C *x, const C *y, size_t n) {return memcmp(x,y,n*sizeof(C));}
2005-09-06 17:55:07 +02:00
static inline bool bigger(const C &x, const C &y) {return x>y;}
static inline bool smaller(const C &x, const C &y) {return x<y;}
2009-11-12 22:01:19 +01:00
static inline normtype norm (const C &x) {return std::abs(x);}
2005-09-06 17:55:07 +02:00
static inline void axpy (C &s, const C &x, const C &c) {s+=x*c;}
2005-09-11 22:04:24 +02:00
static inline void put(int fd, const C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=write(fd,&x,sizeof(C))) laerror("write error");}
static inline void get(int fd, C &x, bool dimensions=0, bool transp=0) {if(sizeof(C)!=read(fd,&x,sizeof(C))) laerror("read error");}
2011-01-18 15:37:05 +01:00
static void multiget(size_t n,int fd, C *x, bool dimensions=0)
{
size_t total=0;
2012-02-23 16:27:05 +01:00
size_t system_limit = (1L<<30)/sizeof(C); //do not expect too much from the system and read at most 1GB at once
2011-01-18 15:37:05 +01:00
ssize_t r;
2013-11-04 15:56:39 +01:00
size_t nn;
2011-01-18 15:37:05 +01:00
do{
2013-11-04 15:56:39 +01:00
r=read(fd,x+total,nn=(n-total > system_limit ? system_limit : n-total)*sizeof(C));
if(r<0 || r==0 && nn!=0 ) {std::cout<<"read returned "<<r<<" perror "<<strerror(errno) <<std::endl; laerror("read error");}
2011-01-18 15:37:05 +01:00
else total += r/sizeof(C);
if(r%sizeof(C)) laerror("read error 2");
}
while(total < n);
}
static void multiput(size_t n, int fd, const C *x, bool dimensions=0)
{
size_t total=0;
2012-02-23 16:27:05 +01:00
size_t system_limit = (1L<<30)/sizeof(C); //do not expect too much from the system and write at most 1GB at once
2011-01-18 15:37:05 +01:00
ssize_t r;
2013-11-04 15:56:39 +01:00
size_t nn;
2011-01-18 15:37:05 +01:00
do{
2013-11-04 15:56:39 +01:00
r=write(fd,x+total,nn=(n-total > system_limit ? system_limit : n-total)*sizeof(C));
if(r<0 || r==0 && nn!=0 ) {std::cout<<"write returned "<<r<<" perror "<<strerror(errno) <<std::endl; laerror("write error");}
2011-01-18 15:37:05 +01:00
else total += r/sizeof(C);
if(r%sizeof(C)) laerror("write error 2");
}
while(total < n);
}
2013-11-04 15:56:39 +01:00
static void copy(C *dest, C *src, size_t n) {memcpy(dest,src,n*sizeof(C));}
static void clear(C *dest, size_t n) {memset(dest,0,n*sizeof(C));}
2007-11-29 14:52:31 +01:00
static void copyonwrite(C &x) {};
2010-12-23 12:30:00 +01:00
static void clearme(C &x) {x=0;};
static void deallocate(C &x) {};
2010-01-11 11:12:28 +01:00
static inline C conjugate(const C &x) {return x;};
static inline C realpart(const C &x) {return x;}
static inline C imagpart(const C &x) {return 0;}
2004-03-17 04:07:21 +01:00
};
2005-09-06 17:55:07 +02:00
//non-scalars except smat
2005-02-14 01:10:07 +01:00
template<typename C>
struct LA_traits; //forward declaration needed for template recursion
#define generate_traits(X) \
template<typename C> \
struct LA_traits_aux<X<C>, scalar_false> { \
typedef C elementtype; \
typedef X<C> producttype; \
typedef typename LA_traits<C>::normtype normtype; \
2010-06-25 17:28:19 +02:00
typedef X<typename LA_traits<C>::realtype> realtype; \
typedef X<typename LA_traits<C>::complextype> complextype; \
2013-11-04 15:56:39 +01:00
static bool gencmp(const C *x, const C *y, size_t n) {for(size_t i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
2005-09-06 17:55:07 +02:00
static inline bool bigger(const C &x, const C &y) {return x>y;} \
static inline bool smaller(const C &x, const C &y) {return x<y;} \
static inline normtype norm (const X<C> &x) {return x.norm();} \
static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);} \
2010-01-07 17:10:12 +01:00
static void put(int fd, const X<C> &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions,transp);} \
static void get(int fd, X<C> &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions,transp);} \
2010-12-23 12:30:00 +01:00
static void multiput(size_t n,int fd, const X<C> *x, bool dimensions=1) {for(size_t i=0; i<n; ++i) x[i].put(fd,dimensions);} \
static void multiget(size_t n,int fd, X<C> *x, bool dimensions=1) {for(size_t i=0; i<n; ++i) x[i].get(fd,dimensions);} \
2013-11-04 15:56:39 +01:00
static void copy(C *dest, C *src, size_t n) {for(size_t i=0; i<n; ++i) dest[i]=src[i];} \
static void clear(C *dest, size_t n) {for(size_t i=0; i<n; ++i) dest[i].clear();}\
2007-11-29 14:52:31 +01:00
static void copyonwrite(X<C> &x) {x.copyonwrite();}\
2009-11-12 22:01:19 +01:00
static void clearme(X<C> &x) {x.clear();}\
2010-12-23 12:30:00 +01:00
static void deallocate(X<C> &x) {x.dealloc();}\
2004-03-17 04:07:21 +01:00
};
2005-02-14 01:10:07 +01:00
//non-scalar types defined in this library
generate_traits(NRMat)
2006-09-18 23:46:45 +02:00
generate_traits(NRMat_from1)
2005-02-14 01:10:07 +01:00
generate_traits(NRVec)
generate_traits(SparseMat)
2009-11-12 22:01:19 +01:00
generate_traits(SparseSMat) //product leading to non-symmetric result not implemented
2010-12-23 12:30:00 +01:00
generate_traits(CSRMat)
2005-02-14 01:10:07 +01:00
#undef generate_traits
2005-09-06 17:55:07 +02:00
//smat
2006-09-18 23:46:45 +02:00
#define generate_traits_smat(X) \
template<typename C> \
struct LA_traits_aux<X<C>, scalar_false> { \
typedef C elementtype; \
typedef NRMat<C> producttype; \
typedef typename LA_traits<C>::normtype normtype; \
2010-06-25 17:28:19 +02:00
typedef X<typename LA_traits<C>::realtype> realtype; \
typedef X<typename LA_traits<C>::complextype> complextype; \
2013-11-04 15:56:39 +01:00
static bool gencmp(const C *x, const C *y, size_t n) {for(size_t i=0; i<n; ++i) if(x[i]!=y[i]) return true; return false;} \
2006-09-18 23:46:45 +02:00
static inline bool bigger(const C &x, const C &y) {return x>y;} \
static inline bool smaller(const C &x, const C &y) {return x<y;} \
static inline normtype norm (const X<C> &x) {return x.norm();} \
static inline void axpy (X<C>&s, const X<C> &x, const C c) {s.axpy(c,x);} \
2010-01-07 17:10:12 +01:00
static void put(int fd, const X<C> &x, bool dimensions=1, bool transp=0) {x.put(fd,dimensions);} \
static void get(int fd, X<C> &x, bool dimensions=1, bool transp=0) {x.get(fd,dimensions);} \
2010-12-23 12:30:00 +01:00
static void multiput(size_t n,int fd, const X<C> *x, bool dimensions=1) {for(size_t i=0; i<n; ++i) x[i].put(fd,dimensions);} \
static void multiget(size_t n,int fd, X<C> *x, bool dimensions=1) {for(size_t i=0; i<n; ++i) x[i].get(fd,dimensions);} \
2013-11-04 15:56:39 +01:00
static void copy(C *dest, C *src, size_t n) {for(size_t i=0; i<n; ++i) dest[i]=src[i];} \
static void clear(C *dest, size_t n) {for(size_t i=0; i<n; ++i) dest[i].clear();} \
2007-11-29 14:52:31 +01:00
static void copyonwrite(X<C> &x) {x.copyonwrite();} \
2009-11-12 22:01:19 +01:00
static void clearme(X<C> &x) {x.clear();} \
2010-12-23 12:30:00 +01:00
static void deallocate(X<C> &x) {x.dealloc();} \
2004-03-17 04:07:21 +01:00
};
2006-09-18 23:46:45 +02:00
generate_traits_smat(NRSMat)
generate_traits_smat(NRSMat_from1)
2004-03-17 04:07:21 +01:00
2009-10-08 16:01:15 +02:00
2005-02-14 01:10:07 +01:00
//the final traits class: inherits the scalar or the non-scalar LA_traits_aux
//specialization, chosen by isscalar<C>::scalar_type
template<typename C>
struct LA_traits : public LA_traits_aux<C, typename isscalar<C>::scalar_type>
{
};
2004-03-17 04:07:21 +01:00
2009-11-12 22:01:19 +01:00
}//namespace
2004-03-17 04:07:21 +01:00
#endif