2004-03-17 04:07:21 +01:00
# include "mat.h"
2005-02-14 01:10:07 +01:00
# include <stdlib.h>
# include <sys/types.h>
# include <sys/stat.h>
# include <fcntl.h>
2005-12-08 13:06:23 +01:00
# include <errno.h>
2005-02-14 01:10:07 +01:00
extern " C " {
extern ssize_t read ( int , void * , size_t ) ;
extern ssize_t write ( int , const void * , size_t ) ;
}
2004-03-17 04:07:21 +01:00
// TODO :
//
/*
* Templates first , specializations for BLAS next
2006-04-01 06:48:01 +02:00
*/
2007-06-22 16:24:55 +02:00
// Direct sum: returns the block-diagonal (nn+rhs.nn) x (mm+rhs.mm) matrix with
// *this in the upper-left block, rhs in the lower-right block and zeros elsewhere.
template <typename T>
const NRMat<T> NRMat<T>::oplus(const NRMat<T> &rhs) const
{
	NRMat<T> sum((T)0, nn + rhs.nn, mm + rhs.mm);
#ifndef oldversion
	sum.storesubmatrix(0, 0, *this);
	sum.storesubmatrix(nn, mm, rhs);
#else
	// legacy element-by-element fill, only compiled when `oldversion` is defined
	for (int ir = 0; ir < nn + rhs.nn; ++ir)
		for (int ic = 0; ic < mm + rhs.mm; ++ic)
		{
			if (ir < nn && ic < mm) sum(ir, ic) = (*this)(ir, ic);
			else if (ir >= nn && ic >= mm) sum(ir, ic) = rhs(ir - nn, ic - mm);
			else sum(ir, ic) = (T)0;
		}
#endif
	return sum;
}
// Direct (Kronecker) product: returns the (nn*rhs.nn) x (mm*rhs.mm) matrix whose
// element (i*rhs.nn + k, j*rhs.mm + l) equals (*this)(i,j) * rhs(k,l).
template <typename T>
const NRMat<T> NRMat<T>::otimes(const NRMat<T> &rhs) const
{
	NRMat<T> r((T)0, nn*rhs.nn, mm*rhs.mm);
	for (int i = 0; i < nn; i++)
		for (int j = 0; j < mm; j++)
		{
			const T c = (*this)(i, j);
			// rows of rhs are indexed by k < rhs.nn and offset by rhs.nn,
			// columns by l < rhs.mm and offset by rhs.mm
			for (int k = 0; k < rhs.nn; k++)
				for (int l = 0; l < rhs.mm; l++)
					r(i*rhs.nn + k, j*rhs.mm + l) = c * rhs(k, l);
		}
	return r;
}
2006-04-01 06:48:01 +02:00
//row of
template < typename T >
2006-09-13 23:29:28 +02:00
const NRVec < T > NRMat < T > : : row ( const int i , int l ) const
2006-04-01 06:48:01 +02:00
{
# ifdef DEBUG
if ( i < 0 | | i > = nn ) laerror ( " illegal index in row() " ) ;
# endif
2006-09-13 23:29:28 +02:00
if ( l < 0 ) l = mm ;
NRVec < T > r ( l ) ;
2006-04-01 06:48:01 +02:00
LA_traits < T > : : copy ( & r [ 0 ] ,
# ifdef MATPTR
v [ i ]
# else
2006-09-13 23:29:28 +02:00
v + i * l
2006-04-01 06:48:01 +02:00
# endif
2006-09-13 23:29:28 +02:00
, l ) ;
2006-04-01 06:48:01 +02:00
return r ;
}
2004-03-17 04:07:21 +01:00
2005-02-14 01:10:07 +01:00
//raw I/O
template < typename T >
2005-09-11 22:04:24 +02:00
void NRMat < T > : : put ( int fd , bool dim , bool transp ) const
2005-02-14 01:10:07 +01:00
{
errno = 0 ;
if ( dim )
{
2005-09-11 22:04:24 +02:00
if ( sizeof ( int ) ! = write ( fd , & ( transp ? mm : nn ) , sizeof ( int ) ) ) laerror ( " cannot write " ) ;
if ( sizeof ( int ) ! = write ( fd , & ( transp ? nn : mm ) , sizeof ( int ) ) ) laerror ( " cannot write " ) ;
2005-02-14 01:10:07 +01:00
}
2005-09-11 22:04:24 +02:00
if ( transp ) //not particularly efficient
{
for ( int j = 0 ; j < mm ; + + j )
for ( int i = 0 ; i < nn ; + + i )
LA_traits < T > : : put ( fd ,
# ifdef MATPTR
v [ i ] [ j ]
# else
v [ i * mm + j ]
# endif
, dim , transp ) ;
}
else LA_traits < T > : : multiput ( nn * mm , fd ,
2005-02-14 01:10:07 +01:00
# ifdef MATPTR
v [ 0 ]
# else
v
# endif
, dim ) ;
}
template < typename T >
2005-09-11 22:04:24 +02:00
void NRMat < T > : : get ( int fd , bool dim , bool transp )
2005-02-14 01:10:07 +01:00
{
int nn0 , mm0 ;
errno = 0 ;
if ( dim )
{
if ( sizeof ( int ) ! = read ( fd , & nn0 , sizeof ( int ) ) ) laerror ( " cannot read " ) ;
if ( sizeof ( int ) ! = read ( fd , & mm0 , sizeof ( int ) ) ) laerror ( " cannot read " ) ;
2005-09-11 22:04:24 +02:00
if ( transp ) resize ( mm0 , nn0 ) ; else resize ( nn0 , mm0 ) ;
2005-02-14 01:10:07 +01:00
}
else
copyonwrite ( ) ;
2005-09-11 22:04:24 +02:00
if ( transp ) //not particularly efficient
{
for ( int j = 0 ; j < mm ; + + j )
for ( int i = 0 ; i < nn ; + + i )
LA_traits < T > : : get ( fd ,
# ifdef MATPTR
v [ i ] [ j ]
# else
v [ i * mm + j ]
# endif
, dim , transp ) ;
}
else LA_traits < T > : : multiget ( nn * mm , fd ,
2005-02-14 01:10:07 +01:00
# ifdef MATPTR
v [ 0 ]
# else
v
# endif
, dim ) ;
}
2004-03-17 04:07:21 +01:00
// Assign scalar: zero the nn*nn storage, then put a on every diagonal element
// (i.e. *this becomes a times the identity). Requires a square matrix.
template <typename T>
NRMat<T> & NRMat<T>::operator=(const T &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror(" RMat.operator=scalar on non-square matrix ");
#endif
#ifdef MATPTR
	memset(v[0], 0, nn*nn*sizeof(T));
	for (int d = 0; d < nn; ++d) v[d][d] = a;
#else
	memset(v, 0, nn*nn*sizeof(T));
	for (int d = 0; d < nn; ++d) v[d*(nn + 1)] = a;
#endif
	return *this;
}
2005-02-01 00:08:03 +01:00
2004-03-17 04:07:21 +01:00
// M += a : adds the scalar a to every diagonal element (square matrices only).
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const T &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror(" Mat.operator+=scalar on non-square matrix ");
#endif
	for (int d = 0; d < nn; ++d)
#ifdef MATPTR
		v[d][d] += a;
#else
		v[d*(nn + 1)] += a;
#endif
	return *this;
}
// M -= a : subtracts the scalar a from every diagonal element (square matrices only).
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const T &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror(" Mat.operator-=scalar on non-square matrix ");
#endif
	for (int d = 0; d < nn; ++d)
#ifdef MATPTR
		v[d][d] -= a;
#else
		v[d*(nn + 1)] -= a;
#endif
	return *this;
}
// unary minus: returns a new matrix with every element negated
template <typename T>
const NRMat<T> NRMat<T>::operator-() const
{
	NRMat<T> result(nn, mm);
	const int total = nn*mm;
	for (int k = 0; k < total; ++k)
#ifdef MATPTR
		result.v[0][k] = -v[0][k];
#else
		result.v[k] = -v[k];
#endif
	return result;
}
// direct sum: block-diagonal (nn+b.nn) x (mm+b.mm) matrix with *this in the
// upper-left block and b in the lower-right block; remaining elements are zero.
// Rows are copied with memcpy, so T is assumed trivially copyable.
template <typename T>
const NRMat<T> NRMat<T>::operator&(const NRMat<T> &b) const
{
	NRMat<T> result((T)0, nn + b.nn, mm + b.mm);
	for (int i = 0; i < nn; i++) memcpy(result[i], (*this)[i], sizeof(T)*mm);
	// b's block starts at column mm (the width of *this), not at column nn
	for (int i = 0; i < b.nn; i++) memcpy(result[nn + i] + mm, b[i], sizeof(T)*b.mm);
	return result;
}
// direct (Kronecker) product: element (i*b.nn + k, j*b.mm + l) = (*this)[i][j] * b[k][l]
template <typename T>
const NRMat<T> NRMat<T>::operator|(const NRMat<T> &b) const
{
	NRMat<T> result(nn*b.nn, mm*b.mm);
	for (int i = 0; i < nn; i++)
	{
		for (int j = 0; j < mm; j++)
		{
			const T aij = (*this)[i][j];
			for (int k = 0; k < b.nn; k++)
				for (int l = 0; l < b.mm; l++)
					result[i*b.nn + k][j*b.mm + l] = aij * b[k][l];
		}
	}
	return result;
}
// sum of columns: result[i] = sum over j of (*this)[i][j], one entry per row (nn entries)
template <typename T>
const NRVec<T> NRMat<T>::csum() const
{
	NRVec<T> result(nn);
	for (int ir = 0; ir < nn; ++ir)
	{
		T acc = (T)0;
		for (int ic = 0; ic < mm; ++ic) acc += (*this)[ir][ic];
		result[ir] = acc;
	}
	return result;
}
// sum of rows: result[i] = sum over j of (*this)[j][i], one entry per column (mm entries)
template <typename T>
const NRVec<T> NRMat<T>::rsum() const
{
	NRVec<T> result(mm); // was sized nn although mm entries are written
	for (int i = 0; i < mm; i++) {
		T sum = (T)0;
		for (int j = 0; j < nn; j++) sum += (*this)[j][i];
		result[i] = sum;
	}
	return result;
}
2005-09-11 22:04:24 +02:00
//block submatrix
// Returns a copy of rows fromrow..torow and columns fromcol..tocol (all inclusive).
template <typename T>
const NRMat<T> NRMat<T>::submatrix(const int fromrow, const int torow, const int fromcol, const int tocol) const
{
#ifdef DEBUG
	if (fromrow < 0 || fromrow >= nn || torow < 0 || torow >= nn || fromcol < 0 || fromcol >= mm || tocol < 0 || tocol >= mm || fromrow > torow || fromcol > tocol) laerror(" bad indices in submatrix ");
#endif
	const int n = torow - fromrow + 1;
	const int m = tocol - fromcol + 1;
	NRMat<T> r(n, m);
	for (int k = 0; k < n; ++k)
	{
		const int src = fromrow + k;
#ifdef MATPTR
		memcpy(r.v[k], v[src] + fromcol, m*sizeof(T));
#else
		memcpy(r.v + k*m, v + src*mm + fromcol, m*sizeof(T));
#endif
	}
	return r;
}
2004-03-17 04:07:21 +01:00
2006-10-21 22:14:13 +02:00
// Copies rhs into *this with its upper-left corner at (fromrow, fromcol).
// The target area must fit inside *this (checked only under DEBUG).
template <typename T>
void NRMat<T>::storesubmatrix(const int fromrow, const int fromcol, const NRMat &rhs)
{
	int tocol = fromcol + rhs.ncols() - 1;
	int torow = fromrow + rhs.nrows() - 1;
#ifdef DEBUG
	if (fromrow < 0 || fromrow >= nn || torow >= nn || fromcol < 0 || fromcol >= mm || tocol >= mm) laerror(" bad indices in storesubmatrix ");
#endif
	// modifies v in place: detach from any shared storage first, like the other mutators
	copyonwrite();
	int m = tocol - fromcol + 1;
	for (int i = fromrow; i <= torow; ++i)
#ifdef MATPTR
		memcpy(v[i] + fromcol, rhs.v[i - fromrow], m*sizeof(T));
#else
		memcpy(v + i*mm + fromcol, rhs.v + (i - fromrow)*m, m*sizeof(T));
#endif
}
2004-03-17 04:07:21 +01:00
// transpose Mat
template < typename T >
2005-02-18 23:08:15 +01:00
NRMat < T > & NRMat < T > : : transposeme ( int n )
2004-03-17 04:07:21 +01:00
{
2005-02-18 23:08:15 +01:00
if ( n = = 0 ) n = nn ;
2004-03-17 04:07:21 +01:00
# ifdef DEBUG
2005-02-18 23:08:15 +01:00
if ( n = = nn & & nn ! = mm | | n > mm | | n > nn ) laerror ( " transpose of non-square Mat " ) ;
2004-03-17 04:07:21 +01:00
# endif
copyonwrite ( ) ;
2005-02-18 23:08:15 +01:00
for ( int i = 1 ; i < n ; i + + )
2004-03-17 04:07:21 +01:00
for ( int j = 0 ; j < i ; j + + ) {
# ifdef MATPTR
T tmp = v [ i ] [ j ] ;
v [ i ] [ j ] = v [ j ] [ i ] ;
v [ j ] [ i ] = tmp ;
# else
register int a ;
register int b ;
a = i * mm + j ;
b = j * mm + i ;
T tmp = v [ a ] ;
v [ a ] = v [ b ] ;
v [ b ] = tmp ;
# endif
}
return * this ;
}
// Output of Mat: formatted text output of all nn x mm elements via lawritemat()
template <typename T>
void NRMat<T>::fprintf(FILE *file, const char *format, const int modulo) const
{
	const T *data = (const T *)(*this);
	lawritemat(file, data, nn, mm, format, 2, modulo, 0);
}
// Input of Mat: reads "n m" followed by n*m elements (row by row) using the
// given scanf format for each element; the matrix is resized to n x m.
template <typename T>
void NRMat<T>::fscanf(FILE *f, const char *format)
{
	int n, m;
	if (std::fscanf(f, " %d %d ", &n, &m) != 2)
		laerror(" cannot read matrix dimensions in Mat::fscanf() ");
	resize(n, m);
	T *p = *this;
	for (int i = 0; i < n; i++)
		for (int j = 0; j < m; j++) // column count is m (was n)
			if (std::fscanf(f, format, p++) != 1)
				laerror(" cannot read matrix element in Mat::fscanf() ");
}
/*
* BLAS specializations for double and complex < double >
*/
2006-04-01 06:48:01 +02:00
// A^T A for a real matrix: mm x mm symmetric result; entry (i,j), j <= i, is the
// dot product of columns i and j (stride mm through the row-major storage).
template <>
const NRSMat<double> NRMat<double>::transposedtimes() const
{
	NRSMat<double> r(mm, mm);
	for (int ci = 0; ci < mm; ++ci)
	{
		for (int cj = 0; cj <= ci; ++cj)
		{
#ifdef MATPTR
			r(ci, cj) = cblas_ddot(nn, v[0] + ci, mm, v[0] + cj, mm);
#else
			r(ci, cj) = cblas_ddot(nn, v + ci, mm, v + cj, mm);
#endif
		}
	}
	return r;
}
// Complex version: entry (i,j), j <= i, is zdotc of columns i and j,
// i.e. sum_k conj(A(k,i)) * A(k,j) (zdotc conjugates its first vector).
template <>
const NRSMat< complex<double> > NRMat< complex<double> >::transposedtimes() const
{
	NRSMat< complex<double> > r(mm, mm);
	for (int ci = 0; ci < mm; ++ci)
	{
		for (int cj = 0; cj <= ci; ++cj)
		{
			void *out = (void *)(&r(ci, cj));
#ifdef MATPTR
			cblas_zdotc_sub(nn, v[0] + ci, mm, v[0] + cj, mm, out);
#else
			cblas_zdotc_sub(nn, v + ci, mm, v + cj, mm, out);
#endif
		}
	}
	return r;
}
//and for general type: plain (unconjugated) sum_k A(k,i) * A(k,j) for j <= i
template <typename T>
const NRSMat<T> NRMat<T>::transposedtimes() const
{
	NRSMat<T> r(mm, mm);
	for (int ci = 0; ci < mm; ++ci)
		for (int cj = 0; cj <= ci; ++cj)
		{
			T acc = (T)0;
			for (int k = 0; k < nn; ++k) acc += (*this)(k, ci) * (*this)(k, cj);
			r(ci, cj) = acc;
		}
	return r;
}
// A A^T for a real matrix: nn x nn symmetric result; entry (i,j), j <= i, is the
// dot product of rows i and j.
template <>
const NRSMat<double> NRMat<double>::timestransposed() const
{
	NRSMat<double> r(nn, nn);
	for (int ri = 0; ri < nn; ++ri)
	{
		for (int rj = 0; rj <= ri; ++rj)
		{
#ifdef MATPTR
			r(ri, rj) = cblas_ddot(mm, v[ri], 1, v[rj], 1);
#else
			r(ri, rj) = cblas_ddot(mm, v + ri*mm, 1, v + rj*mm, 1);
#endif
		}
	}
	return r;
}
// Complex version: entry (i,j), j <= i, is zdotc of rows i and j,
// i.e. sum_k conj(A(i,k)) * A(j,k).
template <>
const NRSMat< complex<double> > NRMat< complex<double> >::timestransposed() const
{
	NRSMat< complex<double> > r(nn, nn);
	for (int ri = 0; ri < nn; ++ri)
	{
		for (int rj = 0; rj <= ri; ++rj)
		{
			void *out = (void *)(&r(ri, rj));
#ifdef MATPTR
			cblas_zdotc_sub(mm, v[ri], 1, v[rj], 1, out);
#else
			cblas_zdotc_sub(mm, v + ri*mm, 1, v + rj*mm, 1, out);
#endif
		}
	}
	return r;
}
//and for general type: plain (unconjugated) sum_k A(i,k) * A(j,k) for j <= i
template <typename T>
const NRSMat<T> NRMat<T>::timestransposed() const
{
	NRSMat<T> r(nn, nn);
	for (int ri = 0; ri < nn; ++ri)
		for (int rj = 0; rj <= ri; ++rj)
		{
			T acc = (T)0;
			for (int k = 0; k < mm; ++k) acc += (*this)(ri, k) * (*this)(rj, k);
			r(ri, rj) = acc;
		}
	return r;
}
2004-03-17 04:07:21 +01:00
// Mat *= a
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator * = ( const double & a )
{
copyonwrite ( ) ;
cblas_dscal ( nn * mm , a , * this , 1 ) ;
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator * = ( const complex < double > & a )
{
copyonwrite ( ) ;
cblas_zscal ( nn * mm , & a , ( void * ) ( * this ) [ 0 ] , 1 ) ;
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: multiply every one of the nn*mm elements by a
template <typename T>
NRMat<T> & NRMat<T>::operator*=(const T &a)
{
	copyonwrite();
	// element count is nn*mm (was nn*nn, wrong for non-square matrices)
#ifdef MATPTR
	for (int i = 0; i < nn*mm; i++) v[0][i] *= a;
#else
	for (int i = 0; i < nn*mm; i++) v[i] *= a;
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat += Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator + = ( const NRMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat += Mat of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_daxpy ( nn * mm , 1.0 , rhs , 1 , * this , 1 ) ;
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator + = ( const NRMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat += Mat of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_zaxpy ( nn * mm , & CONE , ( void * ) rhs [ 0 ] , 1 , ( void * ) ( * this ) [ 0 ] , 1 ) ;
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise addition of all nn*mm elements
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const NRMat<T> &rhs)
{
#ifdef DEBUG
	if (nn != rhs.nn || mm != rhs.mm)
		laerror(" Mat += Mat of incompatible matrices ");
#endif
	copyonwrite();
	// element count is nn*mm (was nn*nn, wrong for non-square matrices)
#ifdef MATPTR
	for (int i = 0; i < nn*mm; i++) v[0][i] += rhs.v[0][i];
#else
	for (int i = 0; i < nn*mm; i++) v[i] += rhs.v[i];
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat -= Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator - = ( const NRMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat -= Mat of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_daxpy ( nn * mm , - 1.0 , rhs , 1 , * this , 1 ) ;
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator - = ( const NRMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat -= Mat of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_zaxpy ( nn * mm , & CMONE , ( void * ) rhs [ 0 ] , 1 , ( void * ) ( * this ) [ 0 ] , 1 ) ;
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise subtraction of all nn*mm elements
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const NRMat<T> &rhs)
{
#ifdef DEBUG
	if (nn != rhs.nn || mm != rhs.mm)
		laerror(" Mat -= Mat of incompatible matrices ");
#endif
	copyonwrite();
	// element count is nn*mm (was nn*nn, wrong for non-square matrices)
#ifdef MATPTR
	for (int i = 0; i < nn*mm; i++) v[0][i] -= rhs.v[0][i];
#else
	for (int i = 0; i < nn*mm; i++) v[i] -= rhs.v[i];
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat += SMat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator + = ( const NRSMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat+=SMat " ) ;
# endif
const double * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
cblas_daxpy ( i + 1 , 1.0 , p , 1 , ( * this ) [ i ] , 1 ) ;
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
cblas_daxpy ( i , 1.0 , p , 1 , ( * this ) [ 0 ] + i , nn ) ;
p + = i + 1 ;
}
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator + = ( const NRSMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat+=SMat " ) ;
# endif
const complex < double > * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
cblas_zaxpy ( i + 1 , ( void * ) & CONE , ( void * ) p , 1 , ( void * ) ( * this ) [ i ] , 1 ) ;
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
cblas_zaxpy ( i , ( void * ) & CONE , ( void * ) p , 1 , ( void * ) ( ( * this ) [ i ] + i ) , nn ) ;
p + = i + 1 ;
}
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: packed lower triangle of rhs is added to (i,j) for j <= i
//and its strict lower part is mirrored into (j,i)
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const NRSMat<T> &rhs)
{
#ifdef DEBUG
	if (nn != mm || nn != rhs.nrows()) laerror(" incompatible matrix size in Mat+=SMat ");
#endif
	const T *p = rhs;
	copyonwrite();
	for (int i = 0; i < nn; i++) {
		for (int j = 0; j < i+1; ++j) (*this)[i][j] += p[j];
		p += i+1;
	}
	p = rhs; p++;
	for (int i = 1; i < nn; i++) {
		// target is element (j,i); the old expression addressed (i+j,i)
		for (int j = 0; j < i; ++j) (*this)[j][i] += p[j];
		p += i+1;
	}
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat -= SMat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator - = ( const NRSMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat-=SMat " ) ;
# endif
const double * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
cblas_daxpy ( i + 1 , - 1.0 , p , 1 , ( * this ) [ i ] , 1 ) ;
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
cblas_daxpy ( i , - 1.0 , p , 1 , ( * this ) [ 0 ] + i , nn ) ;
p + = i + 1 ;
}
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator - = ( const NRSMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat-=SMat " ) ;
# endif
const complex < double > * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
cblas_zaxpy ( i + 1 , ( void * ) & CMONE , ( void * ) p , 1 , ( void * ) ( * this ) [ i ] , 1 ) ;
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
cblas_zaxpy ( i , ( void * ) & CMONE , ( void * ) p , 1 , ( void * ) ( ( * this ) [ i ] + i ) , nn ) ;
p + = i + 1 ;
}
return * this ;
}
2004-03-17 06:34:59 +01:00
//and for general type: packed lower triangle of rhs is subtracted from (i,j) for j <= i
//and its strict lower part from the mirrored element (j,i)
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const NRSMat<T> &rhs)
{
#ifdef DEBUG
	if (nn != mm || nn != rhs.nrows()) laerror(" incompatible matrix size in Mat-=SMat ");
#endif
	const T *p = rhs;
	copyonwrite();
	for (int i = 0; i < nn; i++) {
		for (int j = 0; j < i+1; ++j) (*this)[i][j] -= p[j];
		p += i+1;
	}
	p = rhs; p++;
	for (int i = 1; i < nn; i++) {
		// target is element (j,i); the old expression addressed (i+j,i)
		for (int j = 0; j < i; ++j) (*this)[j][i] -= p[j];
		p += i+1;
	}
	return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Mat.Mat - scalar product
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < double > : : dot ( const NRMat < double > & rhs ) const
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm ) laerror ( " Mat.Mat incompatible matrices " ) ;
# endif
return cblas_ddot ( nn * mm , ( * this ) [ 0 ] , 1 , rhs [ 0 ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const complex < double >
NRMat < complex < double > > : : dot ( const NRMat < complex < double > > & rhs ) const
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm ) laerror ( " Mat.Mat incompatible matrices " ) ;
# endif
complex < double > dot ;
cblas_zdotc_sub ( nn * mm , ( void * ) ( * this ) [ 0 ] , 1 , ( void * ) rhs [ 0 ] , 1 ,
( void * ) ( & dot ) ) ;
return dot ;
}
// Mat * Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < double > NRMat < double > : : operator * ( const NRMat < double > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nn ) laerror ( " product of incompatible matrices " ) ;
2006-09-19 17:59:49 +02:00
if ( rhs . mm < = 0 ) laerror ( " illegal matrix dimension in gemm " ) ;
2004-03-17 04:07:21 +01:00
# endif
NRMat < double > result ( nn , rhs . mm ) ;
cblas_dgemm ( CblasRowMajor , CblasNoTrans , CblasNoTrans , nn , rhs . mm , mm , 1.0 ,
* this , mm , rhs , rhs . mm , 0.0 , result , rhs . mm ) ;
return result ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < complex < double > >
NRMat < complex < double > > : : operator * ( const NRMat < complex < double > > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nn ) laerror ( " product of incompatible matrices " ) ;
# endif
NRMat < complex < double > > result ( nn , rhs . mm ) ;
cblas_zgemm ( CblasRowMajor , CblasNoTrans , CblasNoTrans , nn , rhs . mm , mm ,
( const void * ) ( & CONE ) , ( const void * ) ( * this ) [ 0 ] , mm , ( const void * ) rhs [ 0 ] ,
rhs . mm , ( const void * ) ( & CZERO ) , ( void * ) result [ 0 ] , rhs . mm ) ;
return result ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Multiply by diagonal from L
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : diagmultl ( const NRVec < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultl " ) ;
# endif
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) cblas_dscal ( mm , rhs [ i ] , ( * this ) [ i ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : diagmultl ( const NRVec < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultl " ) ;
# endif
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) cblas_zscal ( mm , & rhs [ i ] , ( * this ) [ i ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Multiply by diagonal from R
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : diagmultr ( const NRVec < double > & rhs )
{
# ifdef DEBUG
if ( mm ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultr " ) ;
# endif
copyonwrite ( ) ;
2006-10-21 17:32:53 +02:00
for ( int i = 0 ; i < mm ; i + + ) cblas_dscal ( nn , rhs [ i ] , & ( * this ) ( 0 , i ) , mm ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : diagmultr ( const NRVec < complex < double > > & rhs )
{
# ifdef DEBUG
if ( mm ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultl " ) ;
# endif
copyonwrite ( ) ;
2006-10-21 17:32:53 +02:00
for ( int i = 0 ; i < mm ; i + + ) cblas_zscal ( nn , & rhs [ i ] , & ( * this ) ( 0 , i ) , mm ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Mat * Smat, decomposed to nn x Vec * Smat
2006-08-16 23:43:45 +02:00
//NOTE: dsymm is not appropriate as it works on UNPACKED symmetric matrix
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < double >
NRMat < double > : : operator * ( const NRSMat < double > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nrows ( ) ) laerror ( " incompatible dimension in Mat*SMat " ) ;
# endif
NRMat < double > result ( nn , rhs . ncols ( ) ) ;
for ( int i = 0 ; i < nn ; i + + )
cblas_dspmv ( CblasRowMajor , CblasLower , mm , 1.0 , & rhs [ 0 ] ,
( * this ) [ i ] , 1 , 0.0 , result [ i ] , 1 ) ;
return result ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < complex < double > >
NRMat < complex < double > > : : operator * ( const NRSMat < complex < double > > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nrows ( ) ) laerror ( " incompatible dimension in Mat*SMat " ) ;
# endif
NRMat < complex < double > > result ( nn , rhs . ncols ( ) ) ;
for ( int i = 0 ; i < nn ; i + + )
cblas_zhpmv ( CblasRowMajor , CblasLower , mm , ( void * ) & CONE , ( void * ) & rhs [ 0 ] ,
( void * ) ( * this ) [ i ] , 1 , ( void * ) & CZERO , ( void * ) result [ i ] , 1 ) ;
return result ;
}
2006-08-15 22:10:08 +02:00
2004-03-17 04:07:21 +01:00
// sum of rows
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRVec < double > NRMat < double > : : rsum ( ) const
{
NRVec < double > result ( mm ) ;
for ( int i = 0 ; i < mm ; i + + ) result [ i ] = cblas_dasum ( nn , ( * this ) [ 0 ] + i , mm ) ;
return result ;
}
// sum of columns
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRVec < double > NRMat < double > : : csum ( ) const
{
NRVec < double > result ( nn ) ;
for ( int i = 0 ; i < nn ; i + + ) result [ i ] = cblas_dasum ( mm , ( * this ) [ i ] , 1 ) ;
return result ;
}
// complex conjugate of Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : conjugateme ( ) { return * this ; }
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > & NRMat < complex < double > > : : conjugateme ( )
{
copyonwrite ( ) ;
cblas_dscal ( mm * nn , - 1.0 , ( double * ) ( ( * this ) [ 0 ] ) + 1 , 2 ) ;
return * this ;
}
// transpose and optionally conjugate
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < double > NRMat < double > : : transpose ( bool conj ) const
{
NRMat < double > result ( mm , nn ) ;
for ( int i = 0 ; i < nn ; i + + ) cblas_dcopy ( mm , ( * this ) [ i ] , 1 , result [ 0 ] + i , nn ) ;
return result ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < complex < double > >
NRMat < complex < double > > : : transpose ( bool conj ) const
{
NRMat < complex < double > > result ( mm , nn ) ;
for ( int i = 0 ; i < nn ; i + + )
cblas_zcopy ( mm , ( void * ) ( * this ) [ i ] , 1 , ( void * ) ( result [ 0 ] + i ) , nn ) ;
if ( conj ) cblas_dscal ( mm * nn , - 1.0 , ( double * ) ( result [ 0 ] ) + 1 , 2 ) ;
return result ;
}
// gemm : this = alpha*op( A )*op( B ) + beta*this
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : gemm ( const double & beta , const NRMat < double > & a ,
const char transa , const NRMat < double > & b , const char transb ,
const double & alpha )
{
int k ( transa = = ' n ' ? a . mm : a . nn ) ;
# ifdef DEBUG
2006-04-06 23:45:51 +02:00
int l ( transa = = ' n ' ? a . nn : a . mm ) ;
int kk ( transb = = ' n ' ? b . nn : b . mm ) ;
int ll ( transb = = ' n ' ? b . mm : b . nn ) ;
2004-03-17 04:07:21 +01:00
if ( l ! = nn | | ll ! = mm | | k ! = kk ) laerror ( " incompatible matrices in Mat:gemm() " ) ;
2006-09-19 17:59:49 +02:00
if ( b . mm < = 0 | | mm < = 0 ) laerror ( " illegal matrix dimension in gemm " ) ;
2004-03-17 04:07:21 +01:00
# endif
if ( alpha = = 0.0 & & beta = = 1.0 ) return ;
copyonwrite ( ) ;
cblas_dgemm ( CblasRowMajor , ( transa = = ' n ' ? CblasNoTrans : CblasTrans ) ,
( transb = = ' n ' ? CblasNoTrans : CblasTrans ) , nn , mm , k , alpha , a ,
a . mm , b , b . mm , beta , * this , mm ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : gemm ( const complex < double > & beta ,
const NRMat < complex < double > > & a , const char transa ,
const NRMat < complex < double > > & b , const char transb ,
const complex < double > & alpha )
{
int k ( transa = = ' n ' ? a . mm : a . nn ) ;
# ifdef DEBUG
2006-04-06 23:45:51 +02:00
int l ( transa = = ' n ' ? a . nn : a . mm ) ;
int kk ( transb = = ' n ' ? b . nn : b . mm ) ;
int ll ( transb = = ' n ' ? b . mm : b . nn ) ;
2004-03-17 04:07:21 +01:00
if ( l ! = nn | | ll ! = mm | | k ! = kk ) laerror ( " incompatible matrices in Mat:gemm() " ) ;
# endif
if ( alpha = = CZERO & & beta = = CONE ) return ;
copyonwrite ( ) ;
cblas_zgemm ( CblasRowMajor ,
( transa = = ' n ' ? CblasNoTrans : ( transa = = ' c ' ? CblasConjTrans : CblasTrans ) ) ,
( transb = = ' n ' ? CblasNoTrans : ( transa = = ' c ' ? CblasConjTrans : CblasTrans ) ) ,
nn , mm , k , & alpha , a , a . mm , b , b . mm , & beta , * this , mm ) ;
}
// norm of Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < double > : : norm ( const double scalar ) const
{
if ( ! scalar ) return cblas_dnrm2 ( nn * mm , ( * this ) [ 0 ] , 1 ) ;
double sum = 0 ;
for ( int i = 0 ; i < nn ; i + + )
for ( int j = 0 ; j < mm ; j + + ) {
register double tmp ;
# ifdef MATPTR
tmp = v [ i ] [ j ] ;
# else
tmp = v [ i * mm + j ] ;
# endif
if ( i = = j ) tmp - = scalar ;
sum + = tmp * tmp ;
}
return sqrt ( sum ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < complex < double > > : : norm ( const complex < double > scalar ) const
{
if ( scalar = = CZERO ) return cblas_dznrm2 ( nn * mm , ( * this ) [ 0 ] , 1 ) ;
double sum = 0 ;
for ( int i = 0 ; i < nn ; i + + )
for ( int j = 0 ; j < mm ; j + + ) {
register complex < double > tmp ;
# ifdef MATPTR
tmp = v [ i ] [ j ] ;
# else
tmp = v [ i * mm + j ] ;
# endif
if ( i = = j ) tmp - = scalar ;
sum + = tmp . real ( ) * tmp . real ( ) + tmp . imag ( ) * tmp . imag ( ) ;
}
return sqrt ( sum ) ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// axpy: this = a * Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : axpy ( const double alpha , const NRMat < double > & mat )
{
# ifdef DEBUG
if ( nn ! = mat . nn | | mm ! = mat . mm ) laerror ( " daxpy of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_daxpy ( nn * mm , alpha , mat , 1 , * this , 1 ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : axpy ( const complex < double > alpha ,
const NRMat < complex < double > > & mat )
{
# ifdef DEBUG
if ( nn ! = mat . nn | | mm ! = mat . mm ) laerror ( " zaxpy of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_zaxpy ( nn * mm , ( void * ) & alpha , mat , 1 , ( void * ) ( * this ) [ 0 ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// trace of Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < double > : : trace ( ) const
{
# ifdef DEBUG
if ( nn ! = mm ) laerror ( " no-square matrix in Mat::trace() " ) ;
# endif
return cblas_dasum ( nn , ( * this ) [ 0 ] , nn + 1 ) ;
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const complex < double > NRMat < complex < double > > : : trace ( ) const
{
# ifdef DEBUG
if ( nn ! = mm ) laerror ( " no-square matrix in Mat::trace() " ) ;
# endif
register complex < double > sum = CZERO ;
for ( int i = 0 ; i < nn * nn ; i + = ( nn + 1 ) )
# ifdef MATPTR
sum + = v [ 0 ] [ i ] ;
# else
sum + = v [ i ] ;
# endif
return sum ;
}
2005-02-02 15:49:33 +01:00
//get diagonal; for compatibility with large matrices do not return newly created object
//for non-square get diagonal of A^TA, will be used as preconditioner
2005-12-08 13:06:23 +01:00
template < >
2006-04-06 23:45:51 +02:00
const double * NRMat < double > : : diagonalof ( NRVec < double > & r , const bool divide , bool cache ) const
2005-02-02 15:49:33 +01:00
{
# ifdef DEBUG
if ( r . size ( ) ! = nn ) laerror ( " diagonalof() incompatible vector " ) ;
# endif
2005-02-04 15:31:42 +01:00
double a ;
r . copyonwrite ( ) ;
2005-02-02 15:49:33 +01:00
if ( nn = = mm )
{
# ifdef MATPTR
2005-02-04 15:31:42 +01:00
if ( divide ) for ( int i = 0 ; i < nn ; i + + ) if ( ( a = v [ i ] [ i ] ) ) r [ i ] / = a ;
else for ( int i = 0 ; i < nn ; i + + ) r [ i ] = v [ i ] [ i ] ;
2005-02-02 15:49:33 +01:00
# else
2005-02-04 15:31:42 +01:00
if ( divide ) { int i , j ; for ( i = j = 0 ; j < nn ; + + j , i + = nn + 1 ) if ( ( a = v [ i ] ) ) r [ j ] / = a ; }
else { int i , j ; for ( i = j = 0 ; j < nn ; + + j , i + = nn + 1 ) r [ j ] = v [ i ] ; }
2005-02-02 15:49:33 +01:00
# endif
}
else //non-square
{
for ( int i = 0 ; i < mm ; i + + )
2005-02-04 15:31:42 +01:00
{
2005-02-02 15:49:33 +01:00
# ifdef MATPTR
2005-02-04 15:31:42 +01:00
a = cblas_ddot ( nn , v [ 0 ] + i , mm , v [ 0 ] + i , mm ) ;
2005-02-02 15:49:33 +01:00
# else
2005-02-04 15:31:42 +01:00
a = cblas_ddot ( nn , v + i , mm , v + i , mm ) ;
2005-02-02 15:49:33 +01:00
# endif
2005-02-04 15:31:42 +01:00
if ( divide ) { if ( a ) r [ i ] / = a ; }
else r [ i ] = a ;
}
2005-02-02 15:49:33 +01:00
}
2006-04-06 23:45:51 +02:00
return divide ? NULL : & r [ 0 ] ;
2005-02-02 15:49:33 +01:00
}
2004-03-17 04:07:21 +01:00
2006-09-10 22:06:44 +02:00
//////////////////////////////////////////////////////////////////////////////
//// forced instantiation in the corresponding object file
//// Explicit instantiations: emit the member definitions in this file (generic
//// templates plus the double / complex<double> specializations above) for each
//// element type listed below.
template class NRMat < double > ;
template class NRMat < complex < double > > ;
template class NRMat < int > ;
template class NRMat < short > ;
template class NRMat < char > ;
template class NRMat < unsigned char > ;
template class NRMat < unsigned int > ;
template class NRMat < unsigned long > ;