2008-02-26 14:55:23 +01:00
/*
LA : linear algebra C + + interface library
Copyright ( C ) 2008 Jiri Pittner < jiri . pittner @ jh - inst . cas . cz > or < jiri @ pittnerovi . com >
complex versions written by Roman Curik < roman . curik @ jh - inst . cas . cz >
This program is free software : you can redistribute it and / or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation , either version 3 of the License , or
( at your option ) any later version .
This program is distributed in the hope that it will be useful ,
but WITHOUT ANY WARRANTY ; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE . See the
GNU General Public License for more details .
You should have received a copy of the GNU General Public License
along with this program . If not , see < http : //www.gnu.org/licenses/>.
*/
2004-03-17 04:07:21 +01:00
# include "mat.h"
2005-02-14 01:10:07 +01:00
# include <stdlib.h>
2010-01-17 21:28:38 +01:00
# include <stdio.h>
2005-02-14 01:10:07 +01:00
# include <sys/types.h>
# include <sys/stat.h>
# include <fcntl.h>
2005-12-08 13:06:23 +01:00
# include <errno.h>
2005-02-14 01:10:07 +01:00
extern " C " {
extern ssize_t read ( int , void * , size_t ) ;
extern ssize_t write ( int , const void * , size_t ) ;
}
2004-03-17 04:07:21 +01:00
// TODO :
//
2009-11-12 22:01:19 +01:00
namespace LA {
2004-03-17 04:07:21 +01:00
/*
* Templates first , specializations for BLAS next
2006-04-01 06:48:01 +02:00
*/
2007-06-22 16:24:55 +02:00
//direct sum: block-diagonal matrix with *this in the upper-left and rhs in the lower-right corner
template <typename T>
const NRMat<T> NRMat<T>::oplus(const NRMat<T> &rhs) const
{
	// an empty operand contributes nothing to the direct sum
	if (nn == 0 && mm == 0) return rhs;
	if (rhs.nn == 0 && rhs.mm == 0) return *this;

	// zero-filled result; the two operands are then stored on the block diagonal
	NRMat<T> sum((T)0, nn + rhs.nn, mm + rhs.mm);
#ifdef oldversion
	for (int row = 0; row < nn; row++)
		for (int col = 0; col < mm + rhs.mm; col++)
			sum(row, col) = col < mm ? (*this)(row, col) : (T)0;
	for (int row = nn; row < nn + rhs.nn; row++)
		for (int col = 0; col < mm + rhs.mm; col++)
			sum(row, col) = col < mm ? (T)0 : rhs(row - nn, col - mm);
#else
	sum.storesubmatrix(0, 0, *this);
	sum.storesubmatrix(nn, mm, rhs);
#endif
	return sum;
}
//direct product
template < typename T >
2007-06-23 23:09:39 +02:00
const NRMat < T > NRMat < T > : : otimes ( const NRMat < T > & rhs , bool reversecolumns ) const
2007-06-22 16:24:55 +02:00
{
2007-06-22 16:46:03 +02:00
if ( nn = = 0 & & mm = = 0 ) return * this ;
if ( rhs . nn = = 0 & & rhs . mm = = 0 ) return rhs ;
2007-06-22 16:24:55 +02:00
NRMat < T > r ( ( T ) 0 , nn * rhs . nn , mm * rhs . mm ) ;
int i , j , k , l ;
2007-06-23 23:09:39 +02:00
if ( reversecolumns )
{
for ( i = 0 ; i < nn ; i + + ) for ( j = 0 ; j < mm ; j + + )
{
T c = ( * this ) ( i , j ) ;
for ( k = 0 ; k < rhs . mm ; k + + ) for ( l = 0 ; l < rhs . mm ; l + + )
r ( i * rhs . nn + k , l * nn + j ) = c * rhs ( k , l ) ;
}
}
else
{
2007-06-22 16:24:55 +02:00
for ( i = 0 ; i < nn ; i + + ) for ( j = 0 ; j < mm ; j + + )
{
T c = ( * this ) ( i , j ) ;
for ( k = 0 ; k < rhs . mm ; k + + ) for ( l = 0 ; l < rhs . mm ; l + + )
r ( i * rhs . nn + k , j * rhs . nn + l ) = c * rhs ( k , l ) ;
}
2007-06-23 23:09:39 +02:00
}
2007-06-22 16:24:55 +02:00
return r ;
}
2006-04-01 06:48:01 +02:00
//row of
template < typename T >
2006-09-13 23:29:28 +02:00
const NRVec < T > NRMat < T > : : row ( const int i , int l ) const
2006-04-01 06:48:01 +02:00
{
# ifdef DEBUG
if ( i < 0 | | i > = nn ) laerror ( " illegal index in row() " ) ;
# endif
2006-09-13 23:29:28 +02:00
if ( l < 0 ) l = mm ;
NRVec < T > r ( l ) ;
2006-04-01 06:48:01 +02:00
LA_traits < T > : : copy ( & r [ 0 ] ,
# ifdef MATPTR
v [ i ]
# else
2006-09-13 23:29:28 +02:00
v + i * l
2006-04-01 06:48:01 +02:00
# endif
2006-09-13 23:29:28 +02:00
, l ) ;
2006-04-01 06:48:01 +02:00
return r ;
}
2004-03-17 04:07:21 +01:00
2005-02-14 01:10:07 +01:00
//raw I/O
template < typename T >
2005-09-11 22:04:24 +02:00
void NRMat < T > : : put ( int fd , bool dim , bool transp ) const
2005-02-14 01:10:07 +01:00
{
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
if ( location ! = cpu )
{
NRMat < T > tmp = * this ;
tmp . moveto ( cpu ) ;
tmp . put ( fd , dim , transp ) ;
return ;
}
# endif
2005-02-14 01:10:07 +01:00
errno = 0 ;
if ( dim )
{
2005-09-11 22:04:24 +02:00
if ( sizeof ( int ) ! = write ( fd , & ( transp ? mm : nn ) , sizeof ( int ) ) ) laerror ( " cannot write " ) ;
if ( sizeof ( int ) ! = write ( fd , & ( transp ? nn : mm ) , sizeof ( int ) ) ) laerror ( " cannot write " ) ;
2005-02-14 01:10:07 +01:00
}
2005-09-11 22:04:24 +02:00
if ( transp ) //not particularly efficient
{
for ( int j = 0 ; j < mm ; + + j )
for ( int i = 0 ; i < nn ; + + i )
LA_traits < T > : : put ( fd ,
# ifdef MATPTR
v [ i ] [ j ]
# else
v [ i * mm + j ]
# endif
, dim , transp ) ;
}
else LA_traits < T > : : multiput ( nn * mm , fd ,
2005-02-14 01:10:07 +01:00
# ifdef MATPTR
v [ 0 ]
# else
v
# endif
, dim ) ;
}
template < typename T >
2005-09-11 22:04:24 +02:00
void NRMat < T > : : get ( int fd , bool dim , bool transp )
2005-02-14 01:10:07 +01:00
{
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
if ( location ! = cpu )
{
NRMat < T > tmp ;
tmp . moveto ( cpu ) ;
tmp . get ( fd , dim , transp ) ;
tmp . moveto ( location ) ;
* this = tmp ;
return ;
}
# endif
2005-02-14 01:10:07 +01:00
int nn0 , mm0 ;
errno = 0 ;
if ( dim )
{
if ( sizeof ( int ) ! = read ( fd , & nn0 , sizeof ( int ) ) ) laerror ( " cannot read " ) ;
if ( sizeof ( int ) ! = read ( fd , & mm0 , sizeof ( int ) ) ) laerror ( " cannot read " ) ;
2005-09-11 22:04:24 +02:00
if ( transp ) resize ( mm0 , nn0 ) ; else resize ( nn0 , mm0 ) ;
2005-02-14 01:10:07 +01:00
}
else
copyonwrite ( ) ;
2005-09-11 22:04:24 +02:00
if ( transp ) //not particularly efficient
{
for ( int j = 0 ; j < mm ; + + j )
for ( int i = 0 ; i < nn ; + + i )
LA_traits < T > : : get ( fd ,
# ifdef MATPTR
v [ i ] [ j ]
# else
v [ i * mm + j ]
# endif
, dim , transp ) ;
}
else LA_traits < T > : : multiget ( nn * mm , fd ,
2005-02-14 01:10:07 +01:00
# ifdef MATPTR
v [ 0 ]
# else
v
# endif
, dim ) ;
}
2004-03-17 04:07:21 +01:00
// Assign diagonal
2010-06-25 17:28:19 +02:00
template < >
NRMat < double > & NRMat < double > : : operator = ( const double & a )
{
copyonwrite ( ) ;
# ifdef DEBUG
if ( nn ! = mm ) laerror ( " RMat.operator=scalar on non-square matrix " ) ;
# endif
# ifdef CUDALA
if ( location = = cpu )
{
# endif
# ifdef MATPTR
memset ( v [ 0 ] , 0 , nn * nn * sizeof ( double ) ) ;
for ( int i = 0 ; i < nn ; i + + ) v [ i ] [ i ] = a ;
# else
double n = 0. ;
cblas_dcopy ( nn * nn , & n , 0 , v , 1 ) ;
cblas_dcopy ( nn , & a , 0 , v , nn + 1 ) ;
# endif
# ifdef CUDALA
}
else
{
double * d = gpuputdouble ( 0. ) ;
cublasDcopy ( nn * nn , d , 0 , v , 1 ) ;
gpufree ( d ) ;
d = gpuputdouble ( a ) ;
cublasDcopy ( nn , d , 0 , v , nn + 1 ) ;
gpufree ( d ) ;
}
# endif
return * this ;
}
2004-03-17 04:07:21 +01:00
// Assign diagonal (general type): zero the square matrix, then put a on the diagonal
template <typename T>
NRMat<T> & NRMat<T>::operator=(const T &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror("RMat.operator=scalar on non-square matrix");
#endif
#ifdef MATPTR
	memset(v[0], 0, nn*nn*sizeof(T));
	for (int k = 0; k < nn; ++k) v[k][k] = a;
#else
	memset(v, 0, nn*nn*sizeof(T));
	for (int k = 0; k < nn; ++k) v[k*(nn + 1)] = a; // diagonal elements are nn+1 apart
#endif
	return *this;
}
2010-06-25 17:28:19 +02:00
// M += a for double: add a to every diagonal element (square matrices only)
template<>
NRMat<double> & NRMat<double>::operator+=(const double &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror("Mat.operator+=scalar on non-square matrix");
#endif
#ifdef CUDALA
	if (location != cpu)
	{
		double *dev = gpuputdouble(a);
		cublasDaxpy(nn, 1.0, dev, 0, *this, nn + 1);
		gpufree(dev);
		return *this;
	}
#endif
#ifdef MATPTR
	for (int k = 0; k < nn; ++k) v[k][k] += a;
#else
	// y += 1.0*x, where x is a stride-0 "vector" of a and y is the diagonal (stride nn+1)
	cblas_daxpy(nn, 1.0, &a, 0, *this, nn + 1);
#endif
	return *this;
}
// M -= a for double: subtract a from every diagonal element (square matrices only)
template<>
NRMat<double> & NRMat<double>::operator-=(const double &a)
{
	copyonwrite();
#ifdef DEBUG
	// message previously said "+=" although this is the subtraction operator
	if (nn != mm) laerror("Mat.operator-=scalar on non-square matrix");
#endif
#ifdef CUDALA
	if (location == cpu)
	{
#endif
#ifdef MATPTR
		for (int i = 0; i < nn; i++) v[i][i] -= a;
#else
		// y += -1.0*x, where x is a stride-0 "vector" of a and y is the diagonal (stride nn+1)
		cblas_daxpy(nn, -1.0, &a, 0, *this, nn + 1);
#endif
#ifdef CUDALA
	}
	else
	{
		double *d = gpuputdouble(a);
		cublasDaxpy(nn, -1.0, d, 0, *this, nn + 1);
		gpufree(d);
	}
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
2005-02-01 00:08:03 +01:00
2004-03-17 04:07:21 +01:00
// M += a: add a to every diagonal element (square matrices only)
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const T &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror("Mat.operator+=scalar on non-square matrix");
#endif
	for (int k = 0; k < nn; ++k)
#ifdef MATPTR
		v[k][k] += a;
#else
		v[k*(nn + 1)] += a; // diagonal elements are nn+1 apart
#endif
	return *this;
}
// M -= a: subtract a from every diagonal element (square matrices only)
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const T &a)
{
	copyonwrite();
#ifdef DEBUG
	if (nn != mm) laerror("Mat.operator-=scalar on non-square matrix");
#endif
	for (int k = 0; k < nn; ++k)
#ifdef MATPTR
		v[k][k] -= a;
#else
		v[k*(nn + 1)] -= a; // diagonal elements are nn+1 apart
#endif
	return *this;
}
2010-06-25 17:28:19 +02:00
// unary minus for double: returns -(*this) and leaves *this untouched.
// The previous BLAS/cuBLAS paths scaled *this in place (inside a const method)
// and returned an uninitialized result.
template<>
const NRMat<double> NRMat<double>::operator-() const
{
	// allocate the result where the operand lives so (cu)BLAS can work on both
	NRMat<double> result(nn, mm, getlocation());
#ifdef CUDALA
	if (location != cpu)
	{
		cublasDcopy(nn*mm, v, 1, result.v, 1);
		cublasDscal(nn*mm, -1., result.v, 1);
		return result;
	}
#endif
#ifdef MATPTR
	for (int i = 0; i < nn*mm; i++) result.v[0][i] = -v[0][i];
#else
	cblas_dcopy(nn*mm, v, 1, result.v, 1);
	cblas_dscal(nn*mm, -1., result.v, 1);
#endif
	return result;
}
2004-03-17 04:07:21 +01:00
// unary minus: element-wise negated copy of the matrix
template <typename T>
const NRMat<T> NRMat<T>::operator-() const
{
	NRMat<T> result(nn, mm);
	const int total = nn*mm; // storage is contiguous, so walk it as a flat array
	for (int k = 0; k < total; ++k)
#ifdef MATPTR
		result.v[0][k] = -v[0][k];
#else
		result.v[k] = -v[k];
#endif
	return result;
}
2010-06-25 17:28:19 +02:00
2004-03-17 04:07:21 +01:00
// direct sum: block-diagonal matrix with *this in the upper-left and b in the lower-right corner
template <typename T>
const NRMat<T> NRMat<T>::operator&(const NRMat<T> &b) const
{
	NRMat<T> result((T)0, nn + b.nn, mm + b.mm);
	for (int i = 0; i < nn; i++) memcpy(result[i], (*this)[i], sizeof(T)*mm);
	// b's block starts at column mm (the width of *this); the old offset nn was wrong for non-square matrices
	for (int i = 0; i < b.nn; i++) memcpy(result[nn + i] + mm, b[i], sizeof(T)*b.mm);
	return result;
}
// direct (Kronecker) product: block (i,j) of the result is (*this)[i][j] * b
template <typename T>
const NRMat<T> NRMat<T>::operator|(const NRMat<T> &b) const
{
	NRMat<T> result(nn*b.nn, mm*b.mm);
	for (int i = 0; i < nn; i++)
		for (int j = 0; j < mm; j++)
		{
			const T aij = (*this)[i][j];
			for (int k = 0; k < b.nn; k++)
			{
				T *dst = result[i*b.nn + k] + j*b.mm; // start of this block's row k
				for (int l = 0; l < b.mm; l++) dst[l] = aij * b[k][l];
			}
		}
	return result;
}
// sum of columns: element i is the sum of row i over all mm columns
template <typename T>
const NRVec<T> NRMat<T>::csum() const
{
	NRVec<T> result(nn);
	for (int i = 0; i < nn; i++)
	{
		const T *rowp = (*this)[i];
		T acc = (T)0;
		for (int j = 0; j < mm; j++) acc += rowp[j];
		result[i] = acc;
	}
	return result;
}
// sum of rows: element j is the sum of column j over all nn rows
template <typename T>
const NRVec<T> NRMat<T>::rsum() const
{
	// one entry per column; it used to be sized nn, which overflowed when mm > nn
	NRVec<T> result(mm);
	for (int j = 0; j < mm; j++)
	{
		T sum = (T)0;
		for (int i = 0; i < nn; i++) sum += (*this)[i][j];
		result[j] = sum;
	}
	return result;
}
2005-09-11 22:04:24 +02:00
//block submatrix: copy of rows fromrow..torow and columns fromcol..tocol (all bounds inclusive)
template <typename T>
const NRMat<T> NRMat<T>::submatrix(const int fromrow, const int torow, const int fromcol, const int tocol) const
{
#ifdef DEBUG
	if (fromrow < 0 || fromrow >= nn || torow < 0 || torow >= nn || fromcol < 0 || fromcol >= mm || tocol < 0 || tocol >= mm || fromrow > torow || fromcol > tocol) laerror("bad indices in submatrix");
#endif
	const int nrow = torow - fromrow + 1;
	const int ncol = tocol - fromcol + 1;
	NRMat<T> sub(nrow, ncol);
	for (int k = 0; k < nrow; ++k)
#ifdef MATPTR
		memcpy(sub.v[k], v[fromrow + k] + fromcol, ncol*sizeof(T));
#else
		memcpy(sub.v + k*ncol, v + (fromrow + k)*mm + fromcol, ncol*sizeof(T));
#endif
	return sub;
}
2004-03-17 04:07:21 +01:00
2006-10-21 22:14:13 +02:00
// overwrite the block starting at (fromrow, fromcol) with the contents of rhs
template <typename T>
void NRMat<T>::storesubmatrix(const int fromrow, const int fromcol, const NRMat &rhs)
{
	const int width = rhs.ncols();
	const int tocol = fromcol + width - 1;
	const int torow = fromrow + rhs.nrows() - 1;
#ifdef DEBUG
	if (fromrow < 0 || fromrow >= nn || torow >= nn || fromcol < 0 || fromcol >= mm || tocol >= mm) laerror("bad indices in storesubmatrix");
#endif
	for (int k = 0; fromrow + k <= torow; ++k)
#ifdef MATPTR
		memcpy(v[fromrow + k] + fromcol, rhs.v[k], width*sizeof(T));
#else
		memcpy(v + (fromrow + k)*mm + fromcol, rhs.v + k*width, width*sizeof(T));
#endif
}
2004-03-17 04:07:21 +01:00
// transpose Mat
template < typename T >
2005-02-18 23:08:15 +01:00
NRMat < T > & NRMat < T > : : transposeme ( int n )
2004-03-17 04:07:21 +01:00
{
2005-02-18 23:08:15 +01:00
if ( n = = 0 ) n = nn ;
2004-03-17 04:07:21 +01:00
# ifdef DEBUG
2005-02-18 23:08:15 +01:00
if ( n = = nn & & nn ! = mm | | n > mm | | n > nn ) laerror ( " transpose of non-square Mat " ) ;
2004-03-17 04:07:21 +01:00
# endif
copyonwrite ( ) ;
2005-02-18 23:08:15 +01:00
for ( int i = 1 ; i < n ; i + + )
2004-03-17 04:07:21 +01:00
for ( int j = 0 ; j < i ; j + + ) {
# ifdef MATPTR
T tmp = v [ i ] [ j ] ;
v [ i ] [ j ] = v [ j ] [ i ] ;
v [ j ] [ i ] = tmp ;
# else
register int a ;
register int b ;
a = i * mm + j ;
b = j * mm + i ;
T tmp = v [ a ] ;
v [ a ] = v [ b ] ;
v [ b ] = tmp ;
# endif
}
return * this ;
}
2009-10-08 16:01:15 +02:00
//complex from real: rhs is copied into the real part (or, with imagpart, the imaginary part);
//the other part is zero.
//NOTE(review): under CUDALA this constructor does not set a location and assumes rhs is host-resident; confirm.
template<>
NRMat<complex<double> >::NRMat(const NRMat<double> &rhs, bool imagpart)
	: nn(rhs.nrows()), mm(rhs.ncols()), count(new int(1))
{
#ifdef MATPTR
	// the row-pointer table was built from undeclared n/m; it must use nn rows of mm elements
	v = new complex<double>*[nn];
	v[0] = new complex<double>[mm*nn];
	for (int i = 1; i < nn; i++) v[i] = v[i-1] + mm;
	complex<double> *data = v[0];
#else
	v = new complex<double>[mm*nn];
	complex<double> *data = v;
#endif
	memset(data, 0, nn*mm*sizeof(complex<double>));
	// complex<double> is stored as {re, im}, so write every second double, offset by one for the imaginary part
	cblas_dcopy(nn*mm, &rhs[0][0], 1, ((double *)data) + (imagpart ? 1 : 0), 2);
}
2004-03-17 04:07:21 +01:00
// Output of Mat: formatted text dump of the nn x mm elements through lawritemat
template <typename T>
void NRMat<T>::fprintf(FILE *file, const char *format, const int modulo) const
{
	const T *data = *this;
	lawritemat(file, data, nn, mm, format, 2, modulo, 0);
}
// Input of Mat: reads "rows cols" followed by rows*cols elements in row-major order,
// each parsed with the given scanf format
template <typename T>
void NRMat<T>::fscanf(FILE *f, const char *format)
{
	int n, m;
	if (::fscanf(f, "%d %d", &n, &m) != 2)
		laerror("cannot read matrix dimensions in Mat::fscanf()");
	resize(n, m);
	T *p = *this;
	for (int i = 0; i < n; i++)
		for (int j = 0; j < m; j++) // was j < n, which read the wrong number of elements for non-square input
			if (::fscanf(f, format, p++) != 1)
				laerror("cannot read matrix element in Mat::fscanf()");
}
/*
* BLAS specializations for double and complex < double >
*/
2006-04-01 06:48:01 +02:00
// A^T * A for double, returned as a packed symmetric mm x mm matrix
template<>
const NRSMat<double> NRMat<double>::transposedtimes() const
{
	NRSMat<double> r(mm, mm);
	for (int i = 0; i < mm; ++i)
	{
		for (int j = 0; j <= i; ++j)
		{
#ifdef MATPTR
			const double *coli = v[0] + i, *colj = v[0] + j;
#else
			const double *coli = v + i, *colj = v + j;
#endif
			// dot product of columns i and j; successive entries of a column are mm apart
			r(i, j) = cblas_ddot(nn, coli, mm, colj, mm);
		}
	}
	return r;
}
// A^H * A for complex: r(i,j) = sum_k conj(A(k,i)) * A(k,j) for j <= i
template<>
const NRSMat<complex<double> > NRMat<complex<double> >::transposedtimes() const
{
	NRSMat<complex<double> > r(mm, mm);
	for (int i = 0; i < mm; ++i)
	{
		for (int j = 0; j <= i; ++j)
		{
#ifdef MATPTR
			const complex<double> *coli = v[0] + i, *colj = v[0] + j;
#else
			const complex<double> *coli = v + i, *colj = v + j;
#endif
			// zdotc conjugates its first argument; column entries are mm apart
			cblas_zdotc_sub(nn, coli, mm, colj, mm, &r(i, j));
		}
	}
	return r;
}
//and for general type: r(i,j) = sum_k A(k,i) * A(k,j) for j <= i (no conjugation)
template <typename T>
const NRSMat<T> NRMat<T>::transposedtimes() const
{
	NRSMat<T> r(mm, mm);
	for (int i = 0; i < mm; ++i)
		for (int j = 0; j <= i; ++j)
		{
			T acc = (T)0;
			for (int k = 0; k < nn; ++k) acc += (*this)(k, i) * (*this)(k, j);
			r(i, j) = acc;
		}
	return r;
}
// A * A^T for double, returned as a packed symmetric nn x nn matrix
template<>
const NRSMat<double> NRMat<double>::timestransposed() const
{
	NRSMat<double> r(nn, nn);
	for (int i = 0; i < nn; ++i)
	{
		for (int j = 0; j <= i; ++j)
		{
#ifdef MATPTR
			const double *rowi = v[i], *rowj = v[j];
#else
			const double *rowi = v + i*mm, *rowj = v + j*mm;
#endif
			// dot product of rows i and j (contiguous, unit stride)
			r(i, j) = cblas_ddot(mm, rowi, 1, rowj, 1);
		}
	}
	return r;
}
// row-by-row conjugated products for complex: r(i,j) = sum_k conj(A(i,k)) * A(j,k) for j <= i
// NOTE(review): zdotc conjugates its first argument (row i). That makes this the complex conjugate
// of (A A^H)(i,j), unlike transposedtimes(), which conjugates the first factor. Confirm the intended convention.
template<>
const NRSMat<complex<double> > NRMat<complex<double> >::timestransposed() const
{
	NRSMat<complex<double> > r(nn, nn);
	for (int i = 0; i < nn; ++i)
	{
		for (int j = 0; j <= i; ++j)
		{
#ifdef MATPTR
			const complex<double> *rowi = v[i], *rowj = v[j];
#else
			const complex<double> *rowi = v + i*mm, *rowj = v + j*mm;
#endif
			cblas_zdotc_sub(mm, rowi, 1, rowj, 1, &r(i, j));
		}
	}
	return r;
}
//and for general type: r(i,j) = sum_k A(i,k) * A(j,k) for j <= i (no conjugation)
template <typename T>
const NRSMat<T> NRMat<T>::timestransposed() const
{
	NRSMat<T> r(nn, nn);
	for (int i = 0; i < nn; ++i)
		for (int j = 0; j <= i; ++j)
		{
			T acc = (T)0;
			for (int k = 0; k < mm; ++k) acc += (*this)(i, k) * (*this)(j, k);
			r(i, j) = acc;
		}
	return r;
}
2008-03-03 16:35:37 +01:00
//randomize: fill with pseudo-random values x*u, where u = 2*random()/(RAND_MAX+1) - 1
template<>
void NRMat<double>::randomize(const double &x)
{
	for (int i = 0; i < nn; ++i)
		for (int j = 0; j < mm; ++j)
		{
			const double u = 2. * random() / (1. + RAND_MAX) - 1.;
			(*this)(i, j) = x * u;
		}
}
2009-10-08 16:01:15 +02:00
//randomize (complex): real and imaginary parts are drawn independently,
//each as x*u with u = 2*random()/(RAND_MAX+1) - 1
template<>
void NRMat<complex<double> >::randomize(const double &x)
{
	for (int i = 0; i < nn; ++i)
		for (int j = 0; j < mm; ++j)
		{
			// Since C++11 complex::real()/imag() return by value, so they cannot be assigned to.
			// Draw the real part first, then the imaginary part, as before.
			const double re = x * (2. * random() / (1. + RAND_MAX) - 1.);
			const double im = x * (2. * random() / (1. + RAND_MAX) - 1.);
			(*this)(i, j) = complex<double>(re, im);
		}
}
2006-04-01 06:48:01 +02:00
2004-03-17 04:07:21 +01:00
// Mat *= a
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator * = ( const double & a )
{
copyonwrite ( ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
if ( location = = cpu )
# endif
cblas_dscal ( nn * mm , a , * this , 1 ) ;
# ifdef CUDALA
else cublasDscal ( nn * mm , a , v , 1 ) ;
# endif
2004-03-17 04:07:21 +01:00
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator * = ( const complex < double > & a )
{
copyonwrite ( ) ;
2009-11-12 22:01:19 +01:00
cblas_zscal ( nn * mm , & a , ( * this ) [ 0 ] , 1 ) ;
2004-03-17 04:07:21 +01:00
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise scaling on host memory only
template <typename T>
NRMat<T> & NRMat<T>::operator*=(const T &a)
{
	NOT_GPU(*this);
	copyonwrite();
	const int total = nn*mm;
	for (int k = 0; k < total; ++k)
#ifdef MATPTR
		v[0][k] *= a;
#else
		v[k] *= a;
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat += Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator + = ( const NRMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat += Mat of incompatible matrices " ) ;
# endif
2010-06-25 17:28:19 +02:00
SAME_LOC ( * this , rhs ) ;
2004-03-17 04:07:21 +01:00
copyonwrite ( ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
if ( location = = cpu )
# endif
2004-03-17 04:07:21 +01:00
cblas_daxpy ( nn * mm , 1.0 , rhs , 1 , * this , 1 ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
else
cublasDaxpy ( nn * mm , 1.0 , rhs , 1 , v , 1 ) ;
# endif
2004-03-17 04:07:21 +01:00
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator + = ( const NRMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat += Mat of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( nn * mm , & CONE , rhs [ 0 ] , 1 , ( * this ) [ 0 ] , 1 ) ;
2004-03-17 04:07:21 +01:00
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise Mat += Mat
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const NRMat<T> &rhs)
{
#ifdef DEBUG
	// message previously said "-=" although this is the addition operator
	if (nn != rhs.nn || mm != rhs.mm)
		laerror("Mat += Mat of incompatible matrices");
#endif
	copyonwrite();
#ifdef MATPTR
	for (int i = 0; i < nn*mm; i++) v[0][i] += rhs.v[0][i];
#else
	for (int i = 0; i < nn*mm; i++) v[i] += rhs.v[i];
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat -= Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator - = ( const NRMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat -= Mat of incompatible matrices " ) ;
# endif
2010-06-25 17:28:19 +02:00
SAME_LOC ( * this , rhs ) ;
2004-03-17 04:07:21 +01:00
copyonwrite ( ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
if ( location = = cpu )
# endif
2004-03-17 04:07:21 +01:00
cblas_daxpy ( nn * mm , - 1.0 , rhs , 1 , * this , 1 ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
else
cublasDaxpy ( nn * mm , - 1.0 , rhs , 1 , v , 1 ) ;
# endif
2004-03-17 04:07:21 +01:00
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator - = ( const NRMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm )
laerror ( " Mat -= Mat of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( nn * mm , & CMONE , rhs [ 0 ] , 1 , ( * this ) [ 0 ] , 1 ) ;
2004-03-17 04:07:21 +01:00
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise Mat -= Mat
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const NRMat<T> &rhs)
{
#ifdef DEBUG
	if (nn != rhs.nn || mm != rhs.mm)
		laerror("Mat -= Mat of incompatible matrices");
#endif
	copyonwrite();
	const int total = nn*mm;
	for (int k = 0; k < total; ++k)
#ifdef MATPTR
		v[0][k] -= rhs.v[0][k];
#else
		v[k] -= rhs.v[k];
#endif
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat += SMat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator + = ( const NRSMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat+=SMat " ) ;
# endif
const double * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
cblas_daxpy ( i + 1 , 1.0 , p , 1 , ( * this ) [ i ] , 1 ) ;
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
cblas_daxpy ( i , 1.0 , p , 1 , ( * this ) [ 0 ] + i , nn ) ;
p + = i + 1 ;
}
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator + = ( const NRSMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat+=SMat " ) ;
# endif
const complex < double > * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( i + 1 , & CONE , p , 1 , ( * this ) [ i ] , 1 ) ;
2004-03-17 04:07:21 +01:00
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( i , & CONE , p , 1 , ( * this ) [ 0 ] + i , nn ) ;
2004-03-17 04:07:21 +01:00
p + = i + 1 ;
}
return * this ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: add a packed lower-triangle symmetric matrix to a square Mat
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const NRSMat<T> &rhs)
{
#ifdef DEBUG
	if (nn != mm || nn != rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat");
#endif
	const T *packed = rhs;
	copyonwrite();
	for (int i = 0; i < nn; i++)
	{
		const T *prow = packed + i*(i + 1)/2; // S(i,0..i)
		for (int j = 0; j <= i; ++j) (*this)[i][j] += prow[j];
		for (int j = 0; j < i; ++j) (*this)[j][i] += prow[j];
	}
	return *this;
}
2004-03-17 04:07:21 +01:00
// Mat -= SMat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : operator - = ( const NRSMat < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat-=SMat " ) ;
# endif
const double * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
cblas_daxpy ( i + 1 , - 1.0 , p , 1 , ( * this ) [ i ] , 1 ) ;
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
cblas_daxpy ( i , - 1.0 , p , 1 , ( * this ) [ 0 ] + i , nn ) ;
p + = i + 1 ;
}
return * this ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > &
NRMat < complex < double > > : : operator - = ( const NRSMat < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = mm | | nn ! = rhs . nrows ( ) ) laerror ( " incompatible matrix size in Mat-=SMat " ) ;
# endif
const complex < double > * p = rhs ;
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( i + 1 , & CMONE , p , 1 , ( * this ) [ i ] , 1 ) ;
2004-03-17 04:07:21 +01:00
p + = i + 1 ;
}
p = rhs ; p + + ;
for ( int i = 1 ; i < nn ; i + + ) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( i , & CMONE , p , 1 , ( * this ) [ 0 ] + i , nn ) ;
2004-03-17 04:07:21 +01:00
p + = i + 1 ;
}
return * this ;
}
2004-03-17 06:34:59 +01:00
//and for general type: subtract a packed lower-triangle symmetric matrix from a square Mat
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const NRSMat<T> &rhs)
{
#ifdef DEBUG
	// message previously said "Mat+=SMat" although this is the subtraction operator
	if (nn != mm || nn != rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat");
#endif
	const T *p = rhs;
	copyonwrite();
	// lower triangle including the diagonal: packed row i is S(i,0..i)
	for (int i = 0; i < nn; i++) {
		for (int j = 0; j < i + 1; ++j) *((*this)[i] + j) -= p[j];
		p += i + 1;
	}
	// strict upper triangle: element (j,i) receives S(i,j); skip the first packed row
	p = rhs; p++;
	for (int i = 1; i < nn; i++) {
		for (int j = 0; j < i; ++j) *((*this)[j] + i) -= p[j];
		p += i + 1;
	}
	return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Mat.Mat - scalar product
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < double > : : dot ( const NRMat < double > & rhs ) const
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm ) laerror ( " Mat.Mat incompatible matrices " ) ;
# endif
return cblas_ddot ( nn * mm , ( * this ) [ 0 ] , 1 , rhs [ 0 ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const complex < double >
NRMat < complex < double > > : : dot ( const NRMat < complex < double > > & rhs ) const
{
# ifdef DEBUG
if ( nn ! = rhs . nn | | mm ! = rhs . mm ) laerror ( " Mat.Mat incompatible matrices " ) ;
# endif
complex < double > dot ;
2009-11-12 22:01:19 +01:00
cblas_zdotc_sub ( nn * mm , ( * this ) [ 0 ] , 1 , rhs [ 0 ] , 1 ,
& dot ) ;
2004-03-17 04:07:21 +01:00
return dot ;
}
// Mat * Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < double > NRMat < double > : : operator * ( const NRMat < double > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nn ) laerror ( " product of incompatible matrices " ) ;
2006-09-19 17:59:49 +02:00
if ( rhs . mm < = 0 ) laerror ( " illegal matrix dimension in gemm " ) ;
2004-03-17 04:07:21 +01:00
# endif
2010-06-25 17:28:19 +02:00
SAME_LOC ( * this , rhs ) ;
NRMat < double > result ( nn , rhs . mm , rhs . getlocation ( ) ) ;
# ifdef CUDALA
if ( location = = cpu )
# endif
2004-03-17 04:07:21 +01:00
cblas_dgemm ( CblasRowMajor , CblasNoTrans , CblasNoTrans , nn , rhs . mm , mm , 1.0 ,
* this , mm , rhs , rhs . mm , 0.0 , result , rhs . mm ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
else
cublasDgemm ( ' N ' , ' N ' , rhs . mm , nn , mm , 1.0 , rhs , rhs . mm , * this , mm , 0.0 , result , rhs . mm ) ;
# endif
2004-03-17 04:07:21 +01:00
return result ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < complex < double > >
NRMat < complex < double > > : : operator * ( const NRMat < complex < double > > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nn ) laerror ( " product of incompatible matrices " ) ;
# endif
NRMat < complex < double > > result ( nn , rhs . mm ) ;
cblas_zgemm ( CblasRowMajor , CblasNoTrans , CblasNoTrans , nn , rhs . mm , mm ,
2009-11-12 22:01:19 +01:00
& CONE , ( * this ) [ 0 ] , mm , rhs [ 0 ] ,
rhs . mm , & CZERO , result [ 0 ] , rhs . mm ) ;
2004-03-17 04:07:21 +01:00
return result ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Multiply by diagonal from L
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : diagmultl ( const NRVec < double > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultl " ) ;
# endif
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) cblas_dscal ( mm , rhs [ i ] , ( * this ) [ i ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : diagmultl ( const NRVec < complex < double > > & rhs )
{
# ifdef DEBUG
if ( nn ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultl " ) ;
# endif
copyonwrite ( ) ;
for ( int i = 0 ; i < nn ; i + + ) cblas_zscal ( mm , & rhs [ i ] , ( * this ) [ i ] , 1 ) ;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Multiply by diagonal from R
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : diagmultr ( const NRVec < double > & rhs )
{
# ifdef DEBUG
if ( mm ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultr " ) ;
# endif
copyonwrite ( ) ;
2006-10-21 17:32:53 +02:00
for ( int i = 0 ; i < mm ; i + + ) cblas_dscal ( nn , rhs [ i ] , & ( * this ) ( 0 , i ) , mm ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : diagmultr ( const NRVec < complex < double > > & rhs )
{
# ifdef DEBUG
if ( mm ! = rhs . size ( ) ) laerror ( " incompatible matrix dimension in diagmultl " ) ;
# endif
copyonwrite ( ) ;
2006-10-21 17:32:53 +02:00
for ( int i = 0 ; i < mm ; i + + ) cblas_zscal ( nn , & rhs [ i ] , & ( * this ) ( 0 , i ) , mm ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Mat * Smat, decomposed to nn x Vec * Smat
2006-08-16 23:43:45 +02:00
//NOTE: dsymm is not appropriate as it works on UNPACKED symmetric matrix
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < double >
NRMat < double > : : operator * ( const NRSMat < double > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nrows ( ) ) laerror ( " incompatible dimension in Mat*SMat " ) ;
# endif
NRMat < double > result ( nn , rhs . ncols ( ) ) ;
for ( int i = 0 ; i < nn ; i + + )
cblas_dspmv ( CblasRowMajor , CblasLower , mm , 1.0 , & rhs [ 0 ] ,
( * this ) [ i ] , 1 , 0.0 , result [ i ] , 1 ) ;
return result ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < complex < double > >
NRMat < complex < double > > : : operator * ( const NRSMat < complex < double > > & rhs ) const
{
# ifdef DEBUG
if ( mm ! = rhs . nrows ( ) ) laerror ( " incompatible dimension in Mat*SMat " ) ;
# endif
NRMat < complex < double > > result ( nn , rhs . ncols ( ) ) ;
for ( int i = 0 ; i < nn ; i + + )
2009-11-12 22:01:19 +01:00
cblas_zhpmv ( CblasRowMajor , CblasLower , mm , & CONE , & rhs [ 0 ] ,
( * this ) [ i ] , 1 , & CZERO , result [ i ] , 1 ) ;
2004-03-17 04:07:21 +01:00
return result ;
}
2006-08-15 22:10:08 +02:00
2004-03-17 04:07:21 +01:00
// complex conjugate of Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < double > & NRMat < double > : : conjugateme ( ) { return * this ; }
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
NRMat < complex < double > > & NRMat < complex < double > > : : conjugateme ( )
{
copyonwrite ( ) ;
cblas_dscal ( mm * nn , - 1.0 , ( double * ) ( ( * this ) [ 0 ] ) + 1 , 2 ) ;
return * this ;
}
// transpose and optionally conjugate
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < double > NRMat < double > : : transpose ( bool conj ) const
{
NRMat < double > result ( mm , nn ) ;
for ( int i = 0 ; i < nn ; i + + ) cblas_dcopy ( mm , ( * this ) [ i ] , 1 , result [ 0 ] + i , nn ) ;
return result ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const NRMat < complex < double > >
NRMat < complex < double > > : : transpose ( bool conj ) const
{
NRMat < complex < double > > result ( mm , nn ) ;
for ( int i = 0 ; i < nn ; i + + )
2009-11-12 22:01:19 +01:00
cblas_zcopy ( mm , ( * this ) [ i ] , 1 , ( result [ 0 ] + i ) , nn ) ;
2004-03-17 04:07:21 +01:00
if ( conj ) cblas_dscal ( mm * nn , - 1.0 , ( double * ) ( result [ 0 ] ) + 1 , 2 ) ;
return result ;
}
// gemm : this = alpha*op( A )*op( B ) + beta*this
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : gemm ( const double & beta , const NRMat < double > & a ,
const char transa , const NRMat < double > & b , const char transb ,
const double & alpha )
{
int k ( transa = = ' n ' ? a . mm : a . nn ) ;
# ifdef DEBUG
2006-04-06 23:45:51 +02:00
int l ( transa = = ' n ' ? a . nn : a . mm ) ;
int kk ( transb = = ' n ' ? b . nn : b . mm ) ;
int ll ( transb = = ' n ' ? b . mm : b . nn ) ;
2004-03-17 04:07:21 +01:00
if ( l ! = nn | | ll ! = mm | | k ! = kk ) laerror ( " incompatible matrices in Mat:gemm() " ) ;
2006-09-19 17:59:49 +02:00
if ( b . mm < = 0 | | mm < = 0 ) laerror ( " illegal matrix dimension in gemm " ) ;
2004-03-17 04:07:21 +01:00
# endif
2010-06-25 17:28:19 +02:00
SAME_LOC3 ( * this , a , b ) ;
2004-03-17 04:07:21 +01:00
if ( alpha = = 0.0 & & beta = = 1.0 ) return ;
copyonwrite ( ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
if ( location = = cpu )
# endif
2004-03-17 04:07:21 +01:00
cblas_dgemm ( CblasRowMajor , ( transa = = ' n ' ? CblasNoTrans : CblasTrans ) ,
( transb = = ' n ' ? CblasNoTrans : CblasTrans ) , nn , mm , k , alpha , a ,
a . mm , b , b . mm , beta , * this , mm ) ;
2010-06-25 17:28:19 +02:00
# ifdef CUDALA
else
cublasDgemm ( transb , transa , mm , nn , k , alpha , b , b . mm , a , a . mm , beta , * this , mm ) ;
# endif
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : gemm ( const complex < double > & beta ,
const NRMat < complex < double > > & a , const char transa ,
const NRMat < complex < double > > & b , const char transb ,
const complex < double > & alpha )
{
int k ( transa = = ' n ' ? a . mm : a . nn ) ;
# ifdef DEBUG
2006-04-06 23:45:51 +02:00
int l ( transa = = ' n ' ? a . nn : a . mm ) ;
int kk ( transb = = ' n ' ? b . nn : b . mm ) ;
int ll ( transb = = ' n ' ? b . mm : b . nn ) ;
2004-03-17 04:07:21 +01:00
if ( l ! = nn | | ll ! = mm | | k ! = kk ) laerror ( " incompatible matrices in Mat:gemm() " ) ;
# endif
if ( alpha = = CZERO & & beta = = CONE ) return ;
copyonwrite ( ) ;
cblas_zgemm ( CblasRowMajor ,
( transa = = ' n ' ? CblasNoTrans : ( transa = = ' c ' ? CblasConjTrans : CblasTrans ) ) ,
( transb = = ' n ' ? CblasNoTrans : ( transa = = ' c ' ? CblasConjTrans : CblasTrans ) ) ,
nn , mm , k , & alpha , a , a . mm , b , b . mm , & beta , * this , mm ) ;
}
// norm of Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < double > : : norm ( const double scalar ) const
{
2010-06-25 17:28:19 +02:00
if ( ! scalar )
{
# ifdef CUDALA
if ( location = = cpu )
# endif
return cblas_dnrm2 ( nn * mm , ( * this ) [ 0 ] , 1 ) ;
# ifdef CUDALA
else
return cublasDnrm2 ( nn * mm , v , 1 ) ;
# endif
}
NOT_GPU ( * this ) ;
2004-03-17 04:07:21 +01:00
double sum = 0 ;
for ( int i = 0 ; i < nn ; i + + )
for ( int j = 0 ; j < mm ; j + + ) {
register double tmp ;
# ifdef MATPTR
tmp = v [ i ] [ j ] ;
# else
tmp = v [ i * mm + j ] ;
# endif
if ( i = = j ) tmp - = scalar ;
sum + = tmp * tmp ;
}
2009-11-12 22:01:19 +01:00
return std : : sqrt ( sum ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
const double NRMat < complex < double > > : : norm ( const complex < double > scalar ) const
{
if ( scalar = = CZERO ) return cblas_dznrm2 ( nn * mm , ( * this ) [ 0 ] , 1 ) ;
double sum = 0 ;
for ( int i = 0 ; i < nn ; i + + )
for ( int j = 0 ; j < mm ; j + + ) {
register complex < double > tmp ;
# ifdef MATPTR
tmp = v [ i ] [ j ] ;
# else
tmp = v [ i * mm + j ] ;
# endif
if ( i = = j ) tmp - = scalar ;
sum + = tmp . real ( ) * tmp . real ( ) + tmp . imag ( ) * tmp . imag ( ) ;
}
2009-11-12 22:01:19 +01:00
return std : : sqrt ( sum ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// axpy: this = a * Mat
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < double > : : axpy ( const double alpha , const NRMat < double > & mat )
{
# ifdef DEBUG
if ( nn ! = mat . nn | | mm ! = mat . mm ) laerror ( " daxpy of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
cblas_daxpy ( nn * mm , alpha , mat , 1 , * this , 1 ) ;
}
2005-12-08 13:06:23 +01:00
template < >
2004-03-17 04:07:21 +01:00
void NRMat < complex < double > > : : axpy ( const complex < double > alpha ,
const NRMat < complex < double > > & mat )
{
# ifdef DEBUG
if ( nn ! = mat . nn | | mm ! = mat . mm ) laerror ( " zaxpy of incompatible matrices " ) ;
# endif
copyonwrite ( ) ;
2009-11-12 22:01:19 +01:00
cblas_zaxpy ( nn * mm , & alpha , mat , 1 , ( * this ) [ 0 ] , 1 ) ;
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// trace of Mat
2008-04-16 14:56:02 +02:00
template < typename T >
const T NRMat < T > : : trace ( ) const
2004-03-17 04:07:21 +01:00
{
# ifdef DEBUG
if ( nn ! = mm ) laerror ( " no-square matrix in Mat::trace() " ) ;
# endif
2008-04-16 14:56:02 +02:00
T sum = 0 ;
2004-03-17 04:07:21 +01:00
# ifdef MATPTR
2008-04-16 14:56:02 +02:00
for ( int i = 0 ; i < nn ; + + i ) sum + = v [ i ] [ i ] ;
2004-03-17 04:07:21 +01:00
# else
2008-04-16 14:56:02 +02:00
for ( int i = 0 ; i < nn * nn ; i + = ( nn + 1 ) ) sum + = v [ i ] ;
2004-03-17 04:07:21 +01:00
# endif
2008-04-16 14:56:02 +02:00
return sum ;
2004-03-17 04:07:21 +01:00
}
2005-02-02 15:49:33 +01:00
2008-04-16 14:56:02 +02:00
2005-02-02 15:49:33 +01:00
//get diagonal; for compatibility with large matrices do not return newly created object
//for non-square get diagonal of A^TA, will be used as preconditioner
2005-12-08 13:06:23 +01:00
template < >
2006-04-06 23:45:51 +02:00
const double * NRMat < double > : : diagonalof ( NRVec < double > & r , const bool divide , bool cache ) const
2005-02-02 15:49:33 +01:00
{
if ( r . size ( ) ! = nn ) laerror ( " diagonalof() incompatible vector " ) ;
2005-02-04 15:31:42 +01:00
double a ;
r . copyonwrite ( ) ;
2005-02-02 15:49:33 +01:00
if ( nn = = mm )
{
# ifdef MATPTR
2005-02-04 15:31:42 +01:00
if ( divide ) for ( int i = 0 ; i < nn ; i + + ) if ( ( a = v [ i ] [ i ] ) ) r [ i ] / = a ;
else for ( int i = 0 ; i < nn ; i + + ) r [ i ] = v [ i ] [ i ] ;
2005-02-02 15:49:33 +01:00
# else
2005-02-04 15:31:42 +01:00
if ( divide ) { int i , j ; for ( i = j = 0 ; j < nn ; + + j , i + = nn + 1 ) if ( ( a = v [ i ] ) ) r [ j ] / = a ; }
else { int i , j ; for ( i = j = 0 ; j < nn ; + + j , i + = nn + 1 ) r [ j ] = v [ i ] ; }
2005-02-02 15:49:33 +01:00
# endif
}
else //non-square
{
for ( int i = 0 ; i < mm ; i + + )
2005-02-04 15:31:42 +01:00
{
2005-02-02 15:49:33 +01:00
# ifdef MATPTR
2005-02-04 15:31:42 +01:00
a = cblas_ddot ( nn , v [ 0 ] + i , mm , v [ 0 ] + i , mm ) ;
2005-02-02 15:49:33 +01:00
# else
2005-02-04 15:31:42 +01:00
a = cblas_ddot ( nn , v + i , mm , v + i , mm ) ;
2005-02-02 15:49:33 +01:00
# endif
2005-02-04 15:31:42 +01:00
if ( divide ) { if ( a ) r [ i ] / = a ; }
else r [ i ] = a ;
}
2005-02-02 15:49:33 +01:00
}
2006-04-06 23:45:51 +02:00
return divide ? NULL : & r [ 0 ] ;
2005-02-02 15:49:33 +01:00
}
2008-03-01 17:55:18 +01:00
//set diagonal: copy r onto the main diagonal of a square matrix;
//off-diagonal elements are left untouched
template<>
void NRMat<double>::diagonalset(const NRVec<double> &r)
{
	if (r.size() != nn) laerror("diagonalset() incompatible vector");
	if (nn != mm) laerror("diagonalset only for square matrix");
	copyonwrite();
	for (int k = 0; k < nn; ++k)
	{
#ifdef MATPTR
		v[k][k] = r[k];
#else
		v[k * (nn + 1)] = r[k];  // diagonal entries are nn+1 apart in flat storage
#endif
	}
}
2009-10-19 21:38:57 +02:00
// In-place modified Gram-Schmidt orthonormalization.
//   rowcol == true : the vectors are the nn rows (length mm)
//   rowcol == false: the vectors are the mm columns (length nn, stride mm)
//   metric == NULL : Euclidean inner product
//   metric != NULL : <x,y> = x . (*metric * y); metric->nrows() must equal the vector length
// Each vector is made orthogonal to all previous (already normalized) vectors and
// then scaled to unit norm. laerror() is raised on a zero vector or, with a
// metric, on a nonpositive <x,x>.
template<>
void NRMat<double>::orthonormalize(const bool rowcol, const NRSMat<double> *metric) //modified Gram-Schmidt
{
	// the vectors are overwritten in place: detach from shared storage first,
	// as every other mutating method in this file does
	copyonwrite();
	if (metric) //general metric
	{
		if (rowcol) //vectors are rows
		{
			if ((*metric).nrows() != mm) laerror("incompatible metric in orthonormalize");
			for (int j = 0; j < nn; ++j)
			{
				for (int i = 0; i < j; ++i)
				{
					NRVec<double> tmp = *metric * (*this).row(i);
					double fact = cblas_ddot(mm, (*this)[j], 1, tmp, 1);
					cblas_daxpy(mm, -fact, (*this)[i], 1, (*this)[j], 1);
				}
				NRVec<double> tmp = *metric * (*this).row(j);
				double norm = cblas_ddot(mm, (*this)[j], 1, tmp, 1);
				if (norm <= 0.) laerror("zero vector in orthonormalize or nonpositive metric");
				cblas_dscal(mm, 1. / std::sqrt(norm), (*this)[j], 1);
			}
		}
		else //vectors are columns
		{
			if ((*metric).nrows() != nn) laerror("incompatible metric in orthonormalize");
			for (int j = 0; j < mm; ++j)
			{
				for (int i = 0; i < j; ++i)
				{
					NRVec<double> tmp = *metric * (*this).column(i);
					double fact = cblas_ddot(nn, &(*this)[0][j], mm, tmp, 1);
					cblas_daxpy(nn, -fact, &(*this)[0][i], mm, &(*this)[0][j], mm);
				}
				NRVec<double> tmp = *metric * (*this).column(j);
				double norm = cblas_ddot(nn, &(*this)[0][j], mm, tmp, 1);
				if (norm <= 0.) laerror("zero vector in orthonormalize or nonpositive metric");
				cblas_dscal(nn, 1. / std::sqrt(norm), &(*this)[0][j], mm);
			}
		}
	}
	else //unit metric
	{
		if (rowcol) //vectors are rows
		{
			for (int j = 0; j < nn; ++j)
			{
				for (int i = 0; i < j; ++i)
				{
					double fact = cblas_ddot(mm, (*this)[j], 1, (*this)[i], 1);
					cblas_daxpy(mm, -fact, (*this)[i], 1, (*this)[j], 1);
				}
				double norm = cblas_dnrm2(mm, (*this)[j], 1);
				if (norm == 0.) laerror("zero vector in orthonormalize");
				cblas_dscal(mm, 1. / norm, (*this)[j], 1);
			}
		}
		else //vectors are columns
		{
			for (int j = 0; j < mm; ++j)
			{
				for (int i = 0; i < j; ++i)
				{
					double fact = cblas_ddot(nn, &(*this)[0][j], mm, &(*this)[0][i], mm);
					cblas_daxpy(nn, -fact, &(*this)[0][i], mm, &(*this)[0][j], mm);
				}
				double norm = cblas_dnrm2(nn, &(*this)[0][j], mm);
				if (norm == 0.) laerror("zero vector in orthonormalize");
				cblas_dscal(nn, 1. / norm, &(*this)[0][j], mm);
			}
		}
	}
}
2005-02-02 15:49:33 +01:00
2004-03-17 04:07:21 +01:00
2010-06-25 17:28:19 +02:00
//------------------------------------------------------------------------------
// for a matrix A(1:nn,1:mm) performs Fortran-like
// operation A(nn:-1:1,:)
//------------------------------------------------------------------------------
// Reverses the row order in place: row i is exchanged with row nn-1-i.
// Assumes contiguous row-major storage (this->v used as a flat array).
template<>
NRMat<double>& NRMat<double>::SwapRows() {
	copyonwrite();
	const int n_pul = this->nn / 2;
	double *const dataIn = this->v;
	// plain `int`: the `register` storage class is ill-formed since C++17
	for (int i = 0; i < n_pul; i++) {
		cblas_dswap(mm, dataIn + i * mm, 1, dataIn + (nn - i - 1) * mm, 1);
	}
	return *this;
}
//------------------------------------------------------------------------------
// Complex specialization: reverses the row order in place (row i <-> row nn-1-i)
// with ZSWAP. Assumes contiguous row-major storage (this->v used as a flat array).
template<>
NRMat< complex<double> >& NRMat< complex<double> >::SwapRows() {
	copyonwrite();
	const int n = this->nn;
	const int m = this->mm;
	const int n_pul = n / 2;
	complex<double> *const dataIn = this->v;
	// plain `int`: the `register` storage class is ill-formed since C++17
	for (int i = 0; i < n_pul; i++) {
		cblas_zswap(m, dataIn + i * m, 1, dataIn + (n - i - 1) * m, 1);
	}
	return *this;
}
//------------------------------------------------------------------------------
// Generic version: reverses the row order in place (row i <-> row nn-1-i).
// Assumes contiguous row-major storage (this->v used as a flat array).
template <typename T>
NRMat<T>& NRMat<T>::SwapRows() {
	copyonwrite();
	const int n = this->nn;
	const int m = this->mm;
	const int n_pul = n / 2;
	T *const dataIn = this->v;
	for (int i = 0; i < n_pul; i++) {
		T *const upper = dataIn + i * m;
		T *const lower = dataIn + (n - i - 1) * m;
		for (int j = 0; j < m; j++) {
			// exchange, not copy: a one-way assignment would lose the upper row
			const T tmp = upper[j];
			upper[j] = lower[j];
			lower[j] = tmp;
		}
	}
	return *this;
}
//------------------------------------------------------------------------------
// for a matrix A(1:nn,1:mm) performs Fortran-like
// operation A(:,mm:-1:1)
//------------------------------------------------------------------------------
// Reverses the column order in place: column i is exchanged with column mm-1-i
// (each column is a strided vector with stride mm).
// Assumes contiguous row-major storage (this->v used as a flat array).
template<>
NRMat<double>& NRMat<double>::SwapCols() {
	copyonwrite();
	const int n = this->nn;
	const int m = this->mm;
	const int m_pul = m / 2;
	double *const dataIn = this->v;
	// plain `int`: the `register` storage class is ill-formed since C++17
	for (int i = 0; i < m_pul; i++) {
		cblas_dswap(n, dataIn + i, m, dataIn + (m - i - 1), m);
	}
	return *this;
}
//------------------------------------------------------------------------------
// Complex specialization: reverses the column order in place
// (column i <-> column mm-1-i, stride mm) with ZSWAP.
// Assumes contiguous row-major storage (this->v used as a flat array).
template<>
NRMat< complex<double> >& NRMat< complex<double> >::SwapCols() {
	copyonwrite();
	const int m_pul = this->mm / 2;
	complex<double> *const dataIn = this->v;
	// plain `int`: the `register` storage class is ill-formed since C++17
	for (int i = 0; i < m_pul; i++) {
		cblas_zswap(nn, dataIn + i, mm, dataIn + (mm - i - 1), mm);
	}
	return *this;
}
//------------------------------------------------------------------------------
// Generic version: reverses the column order in place (column i <-> column mm-1-i).
// Assumes contiguous row-major storage (this->v used as a flat array).
template <typename T>
NRMat<T>& NRMat<T>::SwapCols() {
	copyonwrite();
	const int m_pul = mm / 2;
	T *const dataIn = this->v;
	for (int i = 0; i < m_pul; i++) {
		for (int j = 0; j < nn; j++) {
			const int jm = j * mm;  // start of row j
			// exchange, not copy: a one-way assignment would lose column i
			T &left = dataIn[i + jm];
			T &right = dataIn[(mm - i - 1) + jm];
			const T tmp = left;
			left = right;
			right = tmp;
		}
	}
	return *this;
}
//------------------------------------------------------------------------------
// for a matrix A(1:nn,1:mm) performs Fortran-like
// operation A(nn:-1:1,mm:-1:1)
//------------------------------------------------------------------------------
// Element (i,j) at flat index i*mm+j moves to (nn-1-i)*mm + (mm-1-j), which is
// Dim-1-(i*mm+j). The whole operation is therefore a reversal of the contiguous
// array. It is done by swapping from both ends so that no element is overwritten
// before it is read. The former version copied into the same buffer it was
// reading from and used a row stride of nn instead of mm.
// Assumes contiguous row-major storage (this->v used as a flat array).
template <typename T>
NRMat<T>& NRMat<T>::SwapRowsCols() {
	this->copyonwrite();
	const int Dim = this->nn * this->mm;
	T *const data = this->v;
	for (int lo = 0, hi = Dim - 1; lo < hi; ++lo, --hi) {
		const T tmp = data[lo];
		data[lo] = data[hi];
		data[hi] = tmp;
	}
	return *this;
}
//------------------------------------------------------------------------------
2004-03-17 04:07:21 +01:00
2006-09-10 22:06:44 +02:00
//////////////////////////////////////////////////////////////////////////////
//// forced instantiation in the corresponding object file: every member of
//// NRMat<T> defined in this translation unit is emitted here for these types
// floating-point element types
template class NRMat<double>;
template class NRMat< complex<double> >;
// integer element types
template class NRMat<long long>;
template class NRMat<long>;
template class NRMat<int>;
template class NRMat<short>;
template class NRMat<char>;
template class NRMat<unsigned char>;
template class NRMat<unsigned short>;
template class NRMat<unsigned int>;
template class NRMat<unsigned long>;
template class NRMat<unsigned long long>;
2009-11-12 22:01:19 +01:00
} //namespace