2004-03-17 04:07:21 +01:00
# ifndef _LA_MAT_H_
# define _LA_MAT_H_
2005-02-14 01:10:07 +01:00
# include "la_traits.h"
2004-03-17 04:07:21 +01:00
template < typename T >
class NRMat {
protected :
int nn ;
int mm ;
# ifdef MATPTR
T * * v ;
# else
T * v ;
# endif
int * count ;
public :
friend class NRVec < T > ;
friend class NRSMat < T > ;
inline NRMat ( ) : nn ( 0 ) , mm ( 0 ) , v ( 0 ) , count ( 0 ) { } ;
inline NRMat ( const int n , const int m ) ;
inline NRMat ( const T & a , const int n , const int m ) ;
NRMat ( const T * a , const int n , const int m ) ;
inline NRMat ( const NRMat & rhs ) ;
explicit NRMat ( const NRSMat < T > & rhs ) ;
# ifndef MATPTR
NRMat ( const NRVec < T > & rhs , const int n , const int m ) ;
# endif
~ NRMat ( ) ;
2005-02-14 01:10:07 +01:00
# ifdef MATPTR
2005-09-06 17:55:07 +02:00
const bool operator ! = ( const NRMat & rhs ) const { if ( nn ! = rhs . nn | | mm ! = rhs . mm ) return 1 ; return LA_traits < T > : : gencmp ( v [ 0 ] , rhs . v [ 0 ] , nn * mm ) ; } //memcmp for scalars else elementwise
2005-02-14 01:10:07 +01:00
# else
2005-09-06 17:55:07 +02:00
const bool operator ! = ( const NRMat & rhs ) const { if ( nn ! = rhs . nn | | mm ! = rhs . mm ) return 1 ; return LA_traits < T > : : gencmp ( v , rhs . v , nn * mm ) ; } //memcmp for scalars else elementwise
2005-02-14 01:10:07 +01:00
# endif
const bool operator = = ( const NRMat & rhs ) const { return ! ( * this ! = rhs ) ; } ;
2004-03-17 04:07:21 +01:00
inline int getcount ( ) const { return count ? * count : 0 ; }
NRMat & operator = ( const NRMat & rhs ) ; //assignment
NRMat & operator = ( const T & a ) ; //assign a to diagonal
NRMat & operator | = ( const NRMat & rhs ) ; //assignment to a new copy
NRMat & operator + = ( const T & a ) ; //add diagonal
NRMat & operator - = ( const T & a ) ; //substract diagonal
NRMat & operator * = ( const T & a ) ; //multiply by a scalar
NRMat & operator + = ( const NRMat & rhs ) ;
NRMat & operator - = ( const NRMat & rhs ) ;
NRMat & operator + = ( const NRSMat < T > & rhs ) ;
NRMat & operator - = ( const NRSMat < T > & rhs ) ;
const NRMat operator - ( ) const ; //unary minus
inline const NRMat operator + ( const T & a ) const ;
inline const NRMat operator - ( const T & a ) const ;
inline const NRMat operator * ( const T & a ) const ;
inline const NRMat operator + ( const NRMat & rhs ) const ;
inline const NRMat operator - ( const NRMat & rhs ) const ;
inline const NRMat operator + ( const NRSMat < T > & rhs ) const ;
inline const NRMat operator - ( const NRSMat < T > & rhs ) const ;
2005-02-18 23:08:15 +01:00
const T dot ( const NRMat & rhs ) const ; // scalar product of Mat.Mat//@@@for complex do conjugate
2004-03-17 04:07:21 +01:00
const NRMat operator * ( const NRMat & rhs ) const ; // Mat * Mat
2005-02-14 01:10:07 +01:00
const NRMat oplus ( const NRMat & rhs ) const ; //direct sum
const NRMat otimes ( const NRMat & rhs ) const ; //direct product
2004-03-17 04:07:21 +01:00
void diagmultl ( const NRVec < T > & rhs ) ; //multiply by a diagonal matrix from L
void diagmultr ( const NRVec < T > & rhs ) ; //multiply by a diagonal matrix from R
const NRMat operator * ( const NRSMat < T > & rhs ) const ; // Mat * Smat
const NRMat operator & ( const NRMat & rhs ) const ; // direct sum
const NRMat operator | ( const NRMat < T > & rhs ) const ; // direct product
2005-02-18 23:08:15 +01:00
const NRVec < T > operator * ( const NRVec < T > & rhs ) const { NRVec < T > result ( nn ) ; result . gemv ( ( T ) 0 , * this , ' n ' , ( T ) 1 , rhs ) ; return result ; } ; // Mat * Vec
2004-03-17 04:07:21 +01:00
const NRVec < T > rsum ( ) const ; //sum of rows
const NRVec < T > csum ( ) const ; //sum of columns
2005-02-04 15:31:42 +01:00
void diagonalof ( NRVec < T > & , const bool divide = 0 ) const ; //get diagonal
2004-03-17 04:07:21 +01:00
inline T * operator [ ] ( const int i ) ; //subscripting: pointer to row i
inline const T * operator [ ] ( const int i ) const ;
inline T & operator ( ) ( const int i , const int j ) ; // (i,j) subscripts
inline const T & operator ( ) ( const int i , const int j ) const ;
inline int nrows ( ) const ;
inline int ncols ( ) const ;
2005-02-18 23:08:15 +01:00
inline int size ( ) const ;
2005-02-14 01:10:07 +01:00
void get ( int fd , bool dimensions = 1 ) ;
void put ( int fd , bool dimensions = 1 ) const ;
2004-03-17 04:07:21 +01:00
void copyonwrite ( ) ;
void resize ( const int n , const int m ) ;
inline operator T * ( ) ; //get a pointer to the data
inline operator const T * ( ) const ;
2005-02-18 23:08:15 +01:00
NRMat & transposeme ( int n = 0 ) ; // square matrices only
2004-03-17 04:07:21 +01:00
NRMat & conjugateme ( ) ; // square matrices only
const NRMat transpose ( bool conj = false ) const ;
const NRMat conjugate ( ) const ;
void gemm ( const T & beta , const NRMat & a , const char transa , const NRMat & b ,
const char transb , const T & alpha ) ; //this = alpha*op( A )*op( B ) + beta*this
/*
void strassen ( const T beta , const NRMat & a , const char transa , const NRMat & b ,
const char transb , const T alpha ) ; //this := alpha*op( A )*op( B ) + beta*this
void s_cutoff ( const int , const int , const int , const int ) const ;
*/
void fprintf ( FILE * f , const char * format , const int modulo ) const ;
void fscanf ( FILE * f , const char * format ) ;
const double norm ( const T scalar = ( T ) 0 ) const ;
void axpy ( const T alpha , const NRMat & x ) ; // this += a*x
inline const T amax ( ) const ;
const T trace ( ) const ;
//members concerning sparse matrix
explicit NRMat ( const SparseMat < T > & rhs ) ; // dense from sparse
NRMat & operator + = ( const SparseMat < T > & rhs ) ;
NRMat & operator - = ( const SparseMat < T > & rhs ) ;
2005-09-06 17:55:07 +02:00
void gemm ( const T & beta , const SparseMat < T > & a , const char transa , const NRMat & b , const char transb , const T & alpha ) ; //this = alpha*op( A )*op( B ) + beta*this
2004-03-17 04:07:21 +01:00
inline void simplify ( ) { } ; //just for compatibility with sparse ones
2005-02-18 23:08:15 +01:00
bool issymmetric ( ) const { return 0 ; } ;
2004-03-17 04:07:21 +01:00
//Strassen's multiplication (better than n^3, analogous syntax to gemm)
void strassen ( const T beta , const NRMat & a , const char transa , const NRMat & b , const char transb , const T alpha ) ; //this := alpha*op( A )*op( B ) + beta*this
void s_cutoff ( const int , const int , const int , const int ) const ;
} ;
2005-02-18 23:08:15 +01:00
//due to mutual includes this has to be after full class declaration
# include "vec.h"
# include "smat.h"
# include "sparsemat.h"
2004-03-17 04:07:21 +01:00
// ctors
// Allocate an n x m matrix (elements are not set) owned solely by this object.
template <typename T>
NRMat<T>::NRMat(const int n, const int m) : nn(n), mm(m), count(new int)
{
    *count = 1;
    const int total = m * n;
#ifdef MATPTR
    v = new T*[n];
    v[0] = new T[total];
    // row r starts r*m elements into the contiguous buffer
    for (int row = 1; row < n; ++row) v[row] = v[0] + row * m;
#else
    v = new T[total];
#endif
}
// Allocate an n x m matrix and set every element to a
// (a zero value takes the memset shortcut).
template <typename T>
NRMat<T>::NRMat(const T &a, const int n, const int m) : nn(n), mm(m), count(new int)
{
    *count = 1;
    const int total = m * n;
    T *data;
#ifdef MATPTR
    v = new T*[n];
    data = v[0] = new T[total];
    for (int row = 1; row < n; ++row) v[row] = v[0] + row * m;
#else
    data = v = new T[total];
#endif
    if (a != (T)0) {
        for (int k = 0; k < total; ++k) data[k] = a;
    } else {
        memset(data, 0, n * m * sizeof(T));
    }
}
// Allocate an n x m matrix and bitwise-copy n*m elements from the
// row-major array a.
template <typename T>
NRMat<T>::NRMat(const T *a, const int n, const int m) : nn(n), mm(m), count(new int)
{
    *count = 1;
    const int total = m * n;
#ifdef MATPTR
    v = new T*[n];
    v[0] = new T[total];
    for (int row = 1; row < n; ++row) v[row] = v[0] + row * m;
    T *dest = v[0];
#else
    v = new T[total];
    T *dest = v;
#endif
    memcpy(dest, a, n * m * sizeof(T));
}
// Shallow copy: share rhs's storage and bump the shared reference count
// (an unallocated rhs has no count to bump).
template <typename T>
NRMat<T>::NRMat(const NRMat &rhs) : nn(rhs.nn), mm(rhs.mm), v(rhs.v), count(rhs.count)
{
    if (count != 0) ++*count;
}
// Build a full square matrix from a symmetric NRSMat: rhs[k] is read
// sequentially for (i,j) with j <= i, row by row, and stored at both (i,j)
// and (j,i).
template <typename T>
NRMat<T>::NRMat(const NRSMat<T> &rhs)
{
    nn = mm = rhs.nrows();
    count = new int;
    *count = 1;
#ifdef MATPTR
    v = new T*[nn];
    v[0] = new T[mm * nn];
    for (int row = 1; row < nn; ++row) v[row] = v[0] + row * mm;
#else
    v = new T[mm * nn];
#endif
    int packed = 0;
    for (int row = 0; row < nn; ++row) {
        for (int col = 0; col <= row; ++col) {
#ifdef MATPTR
            v[row][col] = v[col][row] = rhs[packed++];
#else
            v[row * nn + col] = v[col * nn + row] = rhs[packed++];
#endif
        }
    }
}
#ifndef MATPTR
// View vector rhs as an n x m row-major matrix. No data is copied: the
// matrix shares rhs's buffer and reference count.
template <typename T>
NRMat<T>::NRMat(const NRVec<T> &rhs, const int n, const int m)
{
#ifdef DEBUG
    if (n * m != rhs.nn) laerror("matrix dimensions incompatible with vector length");
#endif
    nn = n;
    mm = m;
    count = rhs.count;
    v = rhs.v;
    // rhs may be unallocated (count == 0); guard like the copy constructor does
    if (count) (*count)++;
}
#endif
// Mat + Smat: add rhs to a copy of *this via operator+= and return it.
template <typename T>
inline const NRMat<T> NRMat<T>::operator+(const NRSMat<T> &rhs) const
{
    NRMat<T> sum(*this);
    sum += rhs;
    return sum;
}
// Mat - Smat: subtract rhs from a copy of *this via operator-= and return it.
template <typename T>
inline const NRMat<T> NRMat<T>::operator-(const NRSMat<T> &rhs) const
{
    NRMat<T> difference(*this);
    difference -= rhs;
    return difference;
}
// Mat[i]: writable pointer to the first element of row i.
// In DEBUG builds this checks, in order: the matrix is allocated, its
// storage is not shared (writes through a shared buffer would affect other
// owners), and i is in range. The allocation check comes first so that an
// unallocated matrix (count == 0) is never dereferenced.
template <typename T>
inline T *NRMat<T>::operator[](const int i)
{
#ifdef DEBUG
    if (!v) laerror("[] for unallocated Mat");
    if (count && *count != 1) laerror("Mat lval use of [] with count > 1");
    if (i < 0 || i >= nn) laerror("Mat [] out of range");
#endif
#ifdef MATPTR
    return v[i];
#else
    return v + i * mm;
#endif
}
// Read-only pointer to the first element of row i (DEBUG: range and
// allocation checks).
template <typename T>
inline const T *NRMat<T>::operator[](const int i) const
{
#ifdef DEBUG
    if (!(0 <= i && i < nn)) laerror("Mat [] out of range");
    if (v == 0) laerror("[] for unallocated Mat");
#endif
#ifdef MATPTR
    const T *row = v[i];
#else
    const T *row = v + i * mm;
#endif
    return row;
}
// Mat(i,j): writable reference to element M_{ij}.
// In DEBUG builds this checks that the matrix is allocated, its storage is
// not shared, and 0 <= i < nn, 0 <= j < mm. The column bound is exclusive:
// j == mm is out of range.
template <typename T>
inline T &NRMat<T>::operator()(const int i, const int j)
{
#ifdef DEBUG
    if (!v) laerror("(,) for unallocated Mat");
    if (count && *count != 1) laerror("Mat lval use of (,) with count > 1");
    if (i < 0 || i >= nn || j < 0 || j >= mm) laerror("Mat (,) out of range");
#endif
#ifdef MATPTR
    return v[i][j];
#else
    return v[i * mm + j];
#endif
}
// Read-only reference to element M_{ij}.
// DEBUG: valid indices are 0 <= i < nn and 0 <= j < mm (j == mm is rejected).
template <typename T>
inline const T &NRMat<T>::operator()(const int i, const int j) const
{
#ifdef DEBUG
    if (i < 0 || i >= nn || j < 0 || j >= mm) laerror("Mat (,) out of range");
    if (!v) laerror("(,) for unallocated Mat");
#endif
#ifdef MATPTR
    return v[i][j];
#else
    return v[i * mm + j];
#endif
}
// Row count of the matrix.
template <typename T>
inline int NRMat<T>::nrows() const
{
    return this->nn;
}
// Column count of the matrix.
template <typename T>
inline int NRMat<T>::ncols() const
{
    return this->mm;
}
2005-02-18 23:08:15 +01:00
// Total number of elements (rows * columns).
template <typename T>
inline int NRMat<T>::size() const
{
    const int elements = this->nn * this->mm;
    return elements;
}
2004-03-17 04:07:21 +01:00
// Raw pointer to the first element of the contiguous row-major buffer
// (v[0] under MATPTR). DEBUG: rejects an unallocated matrix.
template <typename T>
inline NRMat<T>::operator T*()
{
#ifdef DEBUG
    if (v == 0) laerror("unallocated Mat in operator T*");
#endif
#ifdef MATPTR
    T *data = v[0];
#else
    T *data = v;
#endif
    return data;
}
// Read-only raw pointer to the first element of the contiguous row-major
// buffer. DEBUG: rejects an unallocated matrix.
template <typename T>
inline NRMat<T>::operator const T*() const
{
#ifdef DEBUG
    if (v == 0) laerror("unallocated Mat in operator T*");
#endif
#ifdef MATPTR
    const T *data = v[0];
#else
    const T *data = v;
#endif
    return data;
}
// Element with the largest absolute value, located with BLAS cblas_idamax
// over all nn*mm elements. Explicit specialization of a member of a class
// template, so it requires the template<> prefix.
// NOTE(review): reads v unconditionally; an unallocated/empty matrix is not handled.
template <>
inline const double NRMat<double>::amax() const
{
#ifdef MATPTR
    return v[0][cblas_idamax(nn * mm, v[0], 1)];
#else
    return v[cblas_idamax(nn * mm, v, 1)];
#endif
}
// Complex counterpart of amax: returns the element at the index chosen by
// BLAS cblas_izamax (BLAS ranks complex entries by |Re| + |Im|). Explicit
// member specialization, so it requires the template<> prefix.
// NOTE(review): reads v unconditionally; an unallocated/empty matrix is not handled.
template <>
inline const complex<double> NRMat<complex<double> >::amax() const
{
#ifdef MATPTR
    return v[0][cblas_izamax(nn * mm, static_cast<const void *>(v[0]), 1)];
#else
    return v[cblas_izamax(nn * mm, static_cast<const void *>(v), 1)];
#endif
}
2004-03-17 17:39:07 +01:00
// Basic members needed for every element type T; they must be defined here in the header.
// Destructor: drop this object's reference; the last owner frees the
// element storage (and, under MATPTR, the row-pointer array) and the counter.
template <typename T>
NRMat<T>::~NRMat()
{
    if (count == 0) return;   // unallocated matrix: nothing to release
    --*count;
    if (*count > 0) return;   // other owners still use the storage
    if (v != 0) {
#ifdef MATPTR
        delete[] v[0];
#endif
        delete[] v;
    }
    delete count;
}
// Shallow assignment: release the current reference (freeing the storage if
// this was the last owner), then share rhs's storage and bump its count.
// Self-assignment is a no-op.
template <typename T>
NRMat<T> &NRMat<T>::operator=(const NRMat<T> &rhs)
{
    if (this == &rhs) return *this;

    if (count && --(*count) == 0) {
#ifdef MATPTR
        delete[] (v[0]);
#endif
        delete[] v;
        delete count;
    }

    v = rhs.v;
    nn = rhs.nn;
    mm = rhs.mm;
    count = rhs.count;
    if (count) ++(*count);
    return *this;
}
// Deep copy: make *this a private copy of rhs (reference count 1).
// Shared storage is detached without being freed. The existing buffer is
// reused when the shape already matches and this is the sole owner.
template <typename T>
NRMat<T> &NRMat<T>::operator|=(const NRMat<T> &rhs)
{
    if (this == &rhs) return *this;
#ifdef DEBUG
    if (!rhs.v) laerror("unallocated rhs in Mat operator |=");
#endif
    // Shared storage: drop our reference and start from the empty state.
    const bool shared = count != 0 && *count > 1;
    if (shared) {
        --(*count);
        nn = mm = 0;
        count = 0;
        v = 0;
    }
    // Shape mismatch: free our buffer (the counter, if any, is reused below).
    const bool reshape = nn != rhs.nn || mm != rhs.mm;
    if (reshape) {
        if (v) {
#ifdef MATPTR
            delete[] (v[0]);
#endif
            delete[] (v);
            v = 0;
        }
        nn = rhs.nn;
        mm = rhs.mm;
    }
    const int total = nn * mm;
    if (!v) {
#ifdef MATPTR
        v = new T*[nn];
        v[0] = new T[total];
#else
        v = new T[total];
#endif
    }
#ifdef MATPTR
    for (int row = 1; row < nn; ++row) v[row] = v[0] + row * mm;
    memcpy(v[0], rhs.v[0], total * sizeof(T));
#else
    memcpy(v, rhs.v, total * sizeof(T));
#endif
    if (!count) count = new int;
    *count = 1;
    return *this;
}
// Detach from shared storage: if other matrices share this buffer, replace
// it with a private deep copy that has its own reference count of 1.
// Nothing happens when this object is already the sole owner.
template <typename T>
void NRMat<T>::copyonwrite()
{
#ifdef DEBUG
    if (!count) laerror("Mat::copyonwrite of undefined matrix");
#endif
    if (*count <= 1) return;

    --(*count);
    count = new int;
    *count = 1;
    const int total = mm * nn;
#ifdef MATPTR
    T **fresh = new T*[nn];
    fresh[0] = new T[total];
    memcpy(fresh[0], v[0], total * sizeof(T));
    for (int row = 1; row < nn; ++row) fresh[row] = fresh[0] + row * mm;
    v = fresh;
#else
    T *fresh = new T[total];
    memcpy(fresh, v, total * sizeof(T));
    v = fresh;
#endif
}
/**
 * Resize to an n x m matrix.
 *
 * - resize(0,0) releases this object's reference (freeing the storage if it
 *   was the last owner) and always leaves the matrix unallocated:
 *   nn == mm == 0, v == count == 0. This holds for an already unallocated
 *   matrix too. Before this fix, that case allocated a counter and a
 *   zero-length buffer, and under MATPTR wrote v[0] past a zero-length
 *   row-pointer array.
 * - If the storage is shared, this object detaches from it first; the other
 *   owners keep their data.
 * - New storage is not initialized and old contents are not copied. Only a
 *   sole owner whose shape is already n x m keeps its data untouched.
 * In DEBUG builds, negative sizes and a zero size in only one dimension are
 * rejected.
 */
template <typename T>
void NRMat<T>::resize(const int n, const int m)
{
#ifdef DEBUG
    if (n < 0 || m < 0 || (n > 0 && m == 0) || (n == 0 && m > 0))
        laerror("illegal dimensions in Mat::resize()");
#endif
    // Empty target: release and go to the unallocated state.
    if (n == 0 && m == 0) {
        if (count && --(*count) <= 0) {
            if (v) {
#ifdef MATPTR
                delete[] (v[0]);
#endif
                delete[] v;
            }
            delete count;
        }
        count = 0;
        nn = mm = 0;
        v = 0;
        return;
    }
    // Shared storage: drop our reference without freeing it.
    if (count && *count > 1) {
        (*count)--;
        count = 0;
        v = 0;
        nn = 0;
        mm = 0;
    }
    // Unallocated (or just detached): allocate fresh storage.
    if (!count) {
        count = new int;
        *count = 1;
        nn = n;
        mm = m;
#ifdef MATPTR
        v = new T*[nn];
        v[0] = new T[m * n];
        for (int i = 1; i < n; i++) v[i] = v[i - 1] + m;
#else
        v = new T[m * n];
#endif
        return;
    }
    // Sole owner (*count == 1): reallocate only if the shape changes.
    if (n != nn || m != mm) {
        nn = n;
        mm = m;
#ifdef MATPTR
        delete[] (v[0]);
#endif
        delete[] v;
#ifdef MATPTR
        v = new T*[nn];
        v[0] = new T[m * n];
        for (int i = 1; i < n; i++) v[i] = v[i - 1] + m;
#else
        v = new T[m * n];
#endif
    }
}
2005-02-06 15:01:27 +01:00
2004-03-17 04:07:21 +01:00
// Stream I/O for NRMat: only declared here; the definitions are not in this
// header (presumably in a compiled source file; not visible here).
template <typename T> extern ostream &operator<<(ostream &s, const NRMat<T> &x);
template <typename T> extern istream &operator>>(istream &s, NRMat<T> &x);
// generate operators: Mat + a, a + Mat, Mat * a
// NOTE(review): NRVECMAT_OPER / NRVECMAT_OPER2 are defined outside this file;
// presumably they supply the inline scalar operators (operator+, -, * with T)
// and matrix operators (operator+, - with NRMat) declared in the class.
NRVECMAT_OPER(Mat, +)
NRVECMAT_OPER(Mat, -)
NRVECMAT_OPER(Mat, *)
// generate Mat + Mat, Mat - Mat
NRVECMAT_OPER2(Mat, +)
NRVECMAT_OPER2(Mat, -)
# endif /* _LA_MAT_H_ */