LA_library/mat.cc

1266 lines
28 KiB
C++
Raw Normal View History

2008-02-26 14:55:23 +01:00
/*
LA: linear algebra C++ interface library
Copyright (C) 2008 Jiri Pittner <jiri.pittner@jh-inst.cas.cz> or <jiri@pittnerovi.com>
complex versions written by Roman Curik <roman.curik@jh-inst.cas.cz>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
2004-03-17 04:07:21 +01:00
#include "mat.h"
2005-02-14 01:10:07 +01:00
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
2005-12-08 13:06:23 +01:00
#include <errno.h>
2005-02-14 01:10:07 +01:00
extern "C" {
extern ssize_t read(int, void *, size_t);
extern ssize_t write(int, const void *, size_t);
}
2004-03-17 04:07:21 +01:00
// TODO :
//
2009-11-12 22:01:19 +01:00
namespace LA {
2004-03-17 04:07:21 +01:00
/*
* Templates first, specializations for BLAS next
2006-04-01 06:48:01 +02:00
*/
2007-06-22 16:24:55 +02:00
//direct sum: block-diagonal matrix with *this in the upper-left and rhs in the lower-right block,
//all other elements zero
template <typename T>
const NRMat<T> NRMat<T>::oplus(const NRMat<T> &rhs) const
{
	//a 0 x 0 operand is neutral
	if(nn==0 && mm==0) return rhs;
	if(rhs.nn==0 && rhs.mm==0) return *this;

	const int rows = nn+rhs.nn;
	const int cols = mm+rhs.mm;
	NRMat<T> r((T)0,rows,cols);
#ifdef oldversion
	//element-wise reference implementation
	for(int i=0;i<rows;i++) for(int j=0;j<cols;j++)
		{
		T x = (T)0;
		if(i<nn && j<mm) x = (*this)(i,j);
		else if(i>=nn && j>=mm) x = rhs(i-nn,j-mm);
		r(i,j) = x;
		}
#else
	r.storesubmatrix(0,0,*this);
	r.storesubmatrix(nn,mm,rhs);
#endif
	return r;
}
//direct product
template <typename T>
2007-06-23 23:09:39 +02:00
const NRMat<T> NRMat<T>::otimes(const NRMat<T> &rhs, bool reversecolumns) const
2007-06-22 16:24:55 +02:00
{
2007-06-22 16:46:03 +02:00
if(nn==0 && mm == 0) return *this;
if(rhs.nn==0 && rhs.mm== 0) return rhs;
2007-06-22 16:24:55 +02:00
NRMat<T> r((T)0,nn*rhs.nn,mm*rhs.mm);
int i,j,k,l;
2007-06-23 23:09:39 +02:00
if(reversecolumns)
{
for(i=0;i<nn;i++) for(j=0;j<mm;j++)
{
T c=(*this)(i,j);
for(k=0;k<rhs.mm;k++) for(l=0;l<rhs.mm;l++)
r( i*rhs.nn+k , l*nn+j ) = c *rhs(k,l);
}
}
else
{
2007-06-22 16:24:55 +02:00
for(i=0;i<nn;i++) for(j=0;j<mm;j++)
{
T c=(*this)(i,j);
for(k=0;k<rhs.mm;k++) for(l=0;l<rhs.mm;l++)
r( i*rhs.nn+k , j*rhs.nn+l ) = c *rhs(k,l);
}
2007-06-23 23:09:39 +02:00
}
2007-06-22 16:24:55 +02:00
return r;
}
2006-04-01 06:48:01 +02:00
//row of
template <typename T>
2006-09-13 23:29:28 +02:00
const NRVec<T> NRMat<T>::row(const int i, int l) const
2006-04-01 06:48:01 +02:00
{
#ifdef DEBUG
if(i<0||i>=nn) laerror("illegal index in row()");
#endif
2006-09-13 23:29:28 +02:00
if(l < 0) l=mm;
NRVec<T> r(l);
2006-04-01 06:48:01 +02:00
LA_traits<T>::copy(&r[0],
#ifdef MATPTR
v[i]
#else
2006-09-13 23:29:28 +02:00
v+i*l
2006-04-01 06:48:01 +02:00
#endif
2006-09-13 23:29:28 +02:00
,l);
2006-04-01 06:48:01 +02:00
return r;
}
2004-03-17 04:07:21 +01:00
2005-02-14 01:10:07 +01:00
//raw I/O
template <typename T>
2005-09-11 22:04:24 +02:00
void NRMat<T>::put(int fd, bool dim, bool transp) const
2005-02-14 01:10:07 +01:00
{
errno=0;
if(dim)
{
2005-09-11 22:04:24 +02:00
if(sizeof(int) != write(fd,&(transp?mm:nn),sizeof(int))) laerror("cannot write");
if(sizeof(int) != write(fd,&(transp?nn:mm),sizeof(int))) laerror("cannot write");
2005-02-14 01:10:07 +01:00
}
2005-09-11 22:04:24 +02:00
if(transp) //not particularly efficient
{
for(int j=0; j<mm; ++j)
for(int i=0; i<nn; ++i)
LA_traits<T>::put(fd,
#ifdef MATPTR
v[i][j]
#else
v[i*mm+j]
#endif
,dim,transp);
}
else LA_traits<T>::multiput(nn*mm,fd,
2005-02-14 01:10:07 +01:00
#ifdef MATPTR
v[0]
#else
v
#endif
,dim);
}
template <typename T>
2005-09-11 22:04:24 +02:00
void NRMat<T>::get(int fd, bool dim, bool transp)
2005-02-14 01:10:07 +01:00
{
int nn0,mm0;
errno=0;
if(dim)
{
if(sizeof(int) != read(fd,&nn0,sizeof(int))) laerror("cannot read");
if(sizeof(int) != read(fd,&mm0,sizeof(int))) laerror("cannot read");
2005-09-11 22:04:24 +02:00
if(transp) resize(mm0,nn0); else resize(nn0,mm0);
2005-02-14 01:10:07 +01:00
}
else
copyonwrite();
2005-09-11 22:04:24 +02:00
if(transp) //not particularly efficient
{
for(int j=0; j<mm; ++j)
for(int i=0; i<nn; ++i)
LA_traits<T>::get(fd,
#ifdef MATPTR
v[i][j]
#else
v[i*mm+j]
#endif
,dim,transp);
}
else LA_traits<T>::multiget(nn*mm,fd,
2005-02-14 01:10:07 +01:00
#ifdef MATPTR
v[0]
#else
v
#endif
,dim);
}
2004-03-17 04:07:21 +01:00
// Assign diagonal: M = a * unit matrix (off-diagonal elements are zeroed)
template <typename T>
NRMat<T> & NRMat<T>::operator=(const T &a)
{
copyonwrite();
#ifdef DEBUG
if (nn != mm) laerror("RMat.operator=scalar on non-square matrix");
#endif
#ifdef MATPTR
memset(v[0],0,nn*nn*sizeof(T));
for (int k=0; k<nn; ++k) v[k][k] = a;
#else
memset(v,0,nn*nn*sizeof(T));
//consecutive diagonal elements of a square row-major matrix are nn+1 apart
for (int k=0; k<nn; ++k) v[k*(nn+1)] = a;
#endif
return *this;
}
2005-02-01 00:08:03 +01:00
2004-03-17 04:07:21 +01:00
// M += a: adds the scalar a to every diagonal element (M + a*1)
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const T &a)
{
copyonwrite();
#ifdef DEBUG
if (nn != mm) laerror("Mat.operator+=scalar on non-square matrix");
#endif
for (int k=0; k<nn; ++k)
#ifdef MATPTR
	v[k][k] += a;
#else
	v[k*(nn+1)] += a;
#endif
return *this;
}
// M -= a: subtracts the scalar a from every diagonal element (M - a*1)
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const T &a)
{
copyonwrite();
#ifdef DEBUG
if (nn != mm) laerror("Mat.operator-=scalar on non-square matrix");
#endif
for (int k=0; k<nn; ++k)
#ifdef MATPTR
	v[k][k] -= a;
#else
	v[k*(nn+1)] -= a;
#endif
return *this;
}
// unary minus: returns a new matrix with every element negated
template <typename T>
const NRMat<T> NRMat<T>::operator-() const
{
NRMat<T> result(nn, mm);
const int total = nn*mm;
#ifdef MATPTR
const T *src = v[0];
T *dst = result.v[0];
#else
const T *src = v;
T *dst = result.v;
#endif
for (int k=0; k<total; ++k) dst[k] = -src[k];
return result;
}
// direct sum: block-diagonal (nn+b.nn) x (mm+b.mm) matrix, *this upper-left, b lower-right
template <typename T>
const NRMat<T> NRMat<T>::operator&(const NRMat<T> & b) const
{
NRMat<T> result((T)0, nn+b.nn, mm+b.mm);
for (int i=0; i<nn; i++) memcpy(result[i], (*this)[i], sizeof(T)*mm);
//b's block starts after the mm columns of *this (the offset was wrongly nn)
for (int i=0; i<b.nn; i++) memcpy(result[nn+i]+mm, b[i], sizeof(T)*b.mm);
return result;
}
// direct product: Kronecker product, element (i,j) of *this scales the block at (i*b.nn, j*b.mm)
template <typename T>
const NRMat<T> NRMat<T>::operator|(const NRMat<T> &b) const
{
NRMat<T> result(nn*b.nn, mm*b.mm);
for (int i=0; i<nn; ++i)
	for (int j=0; j<mm; ++j)
		{
		const T aij = (*this)[i][j];
		for (int k=0; k<b.nn; ++k)
			for (int l=0; l<b.mm; ++l)
				result[i*b.nn+k][j*b.mm+l] = aij*b[k][l];
		}
return result;
}
// sum of columns: element i of the result (length nn) is the sum of all columns' entries in row i
template <typename T>
const NRVec<T> NRMat<T>::csum() const
{
NRVec<T> result(nn);
for (int row=0; row<nn; ++row)
	{
	const T *p = (*this)[row];
	T acc = (T)0;
	for (int col=0; col<mm; ++col) acc += p[col];
	result[row] = acc;
	}
return result;
}
// sum of rows: element i of the result (length mm) is the sum of all rows' entries in column i
template <typename T>
const NRVec<T> NRMat<T>::rsum() const
{
//one entry per column; sizing by nn overflowed whenever mm > nn
NRVec<T> result(mm);
for (int i=0; i<mm; i++) {
	T sum = (T)0;
	for(int j=0; j<nn; j++) sum += (*this)[j][i];
	result[i] = sum;
}
return result;
}
2005-09-11 22:04:24 +02:00
//block submatrix: copy of rows fromrow..torow and columns fromcol..tocol (both ranges inclusive)
template <typename T>
const NRMat<T> NRMat<T>::submatrix(const int fromrow, const int torow, const int fromcol, const int tocol) const
{
#ifdef DEBUG
if(fromrow <0 ||fromrow >=nn||torow <0 ||torow >=nn ||fromcol<0||fromcol>=mm||tocol<0||tocol>=mm||fromrow>torow||fromcol>tocol) laerror("bad indices in submatrix");
#endif
const int rows = torow-fromrow+1;
const int cols = tocol-fromcol+1;
NRMat<T> r(rows,cols);
for(int k=0; k<rows; ++k)
	{
#ifdef MATPTR
	memcpy(r.v[k], v[fromrow+k]+fromcol, cols*sizeof(T));
#else
	memcpy(r.v+k*cols, v+(fromrow+k)*mm+fromcol, cols*sizeof(T));
#endif
	}
return r;
}
2004-03-17 04:07:21 +01:00
2006-10-21 22:14:13 +02:00
//storesubmatrix: overwrite the block of *this starting at (fromrow,fromcol) with the whole of rhs
template <typename T>
void NRMat<T>::storesubmatrix(const int fromrow, const int fromcol, const NRMat &rhs)
{
const int rows = rhs.nrows();
const int cols = rhs.ncols();
const int torow = fromrow+rows-1;
const int tocol = fromcol+cols-1;
#ifdef DEBUG
if(fromrow <0 ||fromrow >=nn||torow >=nn ||fromcol<0||fromcol>=mm||tocol>=mm) laerror("bad indices in storesubmatrix");
#endif
for(int k=0; k<rows; ++k)
	{
#ifdef MATPTR
	memcpy(v[fromrow+k]+fromcol, rhs.v[k], cols*sizeof(T));
#else
	memcpy(v+(fromrow+k)*mm+fromcol, rhs.v+k*cols, cols*sizeof(T));
#endif
	}
}
2004-03-17 04:07:21 +01:00
// transpose Mat
template <typename T>
2005-02-18 23:08:15 +01:00
NRMat<T> & NRMat<T>::transposeme(int n)
2004-03-17 04:07:21 +01:00
{
2005-02-18 23:08:15 +01:00
if(n==0) n=nn;
2004-03-17 04:07:21 +01:00
#ifdef DEBUG
2005-02-18 23:08:15 +01:00
if (n==nn && nn != mm || n>mm || n>nn) laerror("transpose of non-square Mat");
2004-03-17 04:07:21 +01:00
#endif
copyonwrite();
2005-02-18 23:08:15 +01:00
for(int i=1; i<n; i++)
2004-03-17 04:07:21 +01:00
for(int j=0; j<i; j++) {
#ifdef MATPTR
T tmp = v[i][j];
v[i][j] = v[j][i];
v[j][i] = tmp;
#else
register int a;
register int b;
a = i*mm+j;
b = j*mm+i;
T tmp = v[a];
v[a] = v[b];
v[b] = tmp;
#endif
}
return *this;
}
2009-10-08 16:01:15 +02:00
//complex from real: rhs becomes the real part (imagpart==false) or the imaginary part
//(imagpart==true) of the new matrix; the other part is zero
template<>
NRMat<complex<double> >::NRMat(const NRMat<double> &rhs, bool imagpart)
: nn(rhs.nrows()), mm(rhs.ncols()), count(new int(1))
{
#ifdef MATPTR
//row-pointer layout over one contiguous block (the MATPTR branch used undeclared n/m)
v = new complex<double>*[nn];
v[0] = new complex<double>[mm*nn];
for (int i=1; i<nn; i++) v[i] = v[i-1] + mm;
memset(v[0], 0, nn*mm*sizeof(complex<double>));
//view the complex array as interleaved doubles: stride 2, offset 1 selects the imaginary parts
cblas_dcopy(nn*mm,&rhs[0][0],1,((double *)v[0]) + (imagpart?1:0),2);
#else
v = new complex<double>[mm*nn];
memset(v, 0, nn*mm*sizeof(complex<double>));
cblas_dcopy(nn*mm,&rhs[0][0],1,((double *)v) + (imagpart?1:0),2);
#endif
}
2004-03-17 04:07:21 +01:00
// Output of Mat: formatted text output through lawritemat using the given element format
template <typename T>
void NRMat<T>::fprintf(FILE *file, const char *format, const int modulo) const
{
const T *data = (const T*)(*this);
lawritemat(file, data, nn, mm, format, 2, modulo, 0);
}
// Input of Mat: reads "rows cols" followed by rows*cols elements (row-major) using format
template <typename T>
void NRMat<T>::fscanf(FILE *f, const char *format)
{
int n, m;
if (std::fscanf(f, "%d %d", &n, &m) != 2)
	laerror("cannot read matrix dimensions in Mat::fscanf()");
resize(n,m);
T *p = *this;
for(int i=0; i<n; i++)
	for(int j=0; j<m; j++) //m columns per row (was n: wrong count for non-square input)
		if(std::fscanf(f,format,p++) != 1)
			laerror("cannot read matrix element in Mat::fscanf()");
}
/*
* BLAS specializations for double and complex<double>
*/
2006-04-01 06:48:01 +02:00
//transposedtimes: symmetric mm x mm matrix A^T A; element (i,j) is the dot product of columns i and j
template<>
const NRSMat<double> NRMat<double>::transposedtimes() const
{
NRSMat<double> r(mm,mm);
#ifdef MATPTR
const double *base = v[0];
#else
const double *base = v;
#endif
for(int i=0; i<mm; ++i)
	for(int j=0; j<=i; ++j)
		r(i,j) = cblas_ddot(nn, base+i, mm, base+j, mm); //column stride is mm
return r;
}
//transposedtimes, complex: element (i,j), j<=i, is zdotc of columns i and j,
//i.e. sum_k conj(A(k,i)) * A(k,j)
template<>
const NRSMat<complex<double> > NRMat<complex<double> >::transposedtimes() const
{
NRSMat<complex<double> > r(mm,mm);
#ifdef MATPTR
const complex<double> *base = v[0];
#else
const complex<double> *base = v;
#endif
for(int i=0; i<mm; ++i)
	for(int j=0; j<=i; ++j)
		cblas_zdotc_sub(nn, base+i, mm, base+j, mm, &r(i,j));
return r;
}
//and for general type: A^T A without BLAS, lower triangle (j<=i) accumulated row by row
template <typename T>
const NRSMat<T> NRMat<T>::transposedtimes() const
{
NRSMat<T> r(mm,mm);
for(int i=0; i<mm; ++i)
	for(int j=0; j<=i; ++j)
		{
		T acc = (T)0;
		for(int k=0; k<nn; ++k)
			acc += (*this)(k,i) * (*this)(k,j);
		r(i,j) = acc;
		}
return r;
}
//timestransposed: symmetric nn x nn matrix A A^T; element (i,j) is the dot product of rows i and j
template<>
const NRSMat<double> NRMat<double>::timestransposed() const
{
NRSMat<double> r(nn,nn);
for(int i=0; i<nn; ++i)
	{
#ifdef MATPTR
	const double *rowi = v[i];
#else
	const double *rowi = v+i*mm;
#endif
	for(int j=0; j<=i; ++j)
		{
#ifdef MATPTR
		const double *rowj = v[j];
#else
		const double *rowj = v+j*mm;
#endif
		r(i,j) = cblas_ddot(mm, rowi, 1, rowj, 1);
		}
	}
return r;
}
//timestransposed, complex: element (i,j), j<=i, is zdotc of rows i and j,
//i.e. sum_k conj(A(i,k)) * A(j,k)
template<>
const NRSMat<complex<double> > NRMat<complex<double> >::timestransposed() const
{
NRSMat<complex<double> > r(nn,nn);
for(int i=0; i<nn; ++i)
	{
#ifdef MATPTR
	const complex<double> *rowi = v[i];
#else
	const complex<double> *rowi = v+i*mm;
#endif
	for(int j=0; j<=i; ++j)
		{
#ifdef MATPTR
		const complex<double> *rowj = v[j];
#else
		const complex<double> *rowj = v+j*mm;
#endif
		cblas_zdotc_sub(mm, rowi, 1, rowj, 1, &r(i,j));
		}
	}
return r;
}
//and for general type: A A^T without BLAS, lower triangle (j<=i) only
template <typename T>
const NRSMat<T> NRMat<T>::timestransposed() const
{
NRSMat<T> r(nn,nn);
for(int i=0; i<nn; ++i)
	for(int j=0; j<=i; ++j)
		{
		T acc = (T)0;
		for(int k=0; k<mm; ++k)
			acc += (*this)(i,k) * (*this)(j,k);
		r(i,j) = acc;
		}
return r;
}
2008-03-03 16:35:37 +01:00
//randomize: every element becomes x*u with u = 2*random()/(1+RAND_MAX) - 1
//(roughly uniform in [-1,1), assuming random() spans [0,RAND_MAX] - NOTE(review): platform dependent)
template<>
void NRMat<double>::randomize(const double &x)
{
for(int i=0; i<nn; ++i)
	for(int j=0; j<mm; ++j)
		{
		const double u = 2.*random()/(1.+RAND_MAX) -1.;
		(*this)(i,j) = x*u;
		}
}
2009-10-08 16:01:15 +02:00
//randomize, complex: real and imaginary parts are drawn independently as
//x*(2*random()/(1+RAND_MAX) - 1); the real part is drawn first
template<>
void NRMat<complex<double> >::randomize(const double &x)
{
for(int i=0; i<nn; ++i)
	for(int j=0; j<mm; ++j)
		{
		const double re = x*(2.*random()/(1.+RAND_MAX) -1.);
		const double im = x*(2.*random()/(1.+RAND_MAX) -1.);
		//complex::real()/imag() return by value since C++11 and cannot be assigned to
		(*this)(i,j) = complex<double>(re,im);
		}
}
2006-04-01 06:48:01 +02:00
2004-03-17 04:07:21 +01:00
// Mat *= a
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat<double> & NRMat<double>::operator*=(const double &a)
{
copyonwrite();
cblas_dscal(nn*mm, a, *this, 1);
return *this;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat< complex<double> > &
NRMat< complex<double> >::operator*=(const complex<double> &a)
{
copyonwrite();
2009-11-12 22:01:19 +01:00
cblas_zscal(nn*mm, &a, (*this)[0], 1);
2004-03-17 04:07:21 +01:00
return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: scales every element by a
template <typename T>
NRMat<T> & NRMat<T>::operator*=(const T &a)
{
copyonwrite();
//the matrix has nn*mm elements (nn*nn was wrong for non-square matrices)
const int total = nn*mm;
#ifdef MATPTR
for (int i=0; i< total; i++) v[0][i] *= a;
#else
for (int i=0; i< total; i++) v[i] *= a;
#endif
return *this;
}
2004-03-17 04:07:21 +01:00
// Mat += Mat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat<double> & NRMat<double>::operator+=(const NRMat<double> &rhs)
{
#ifdef DEBUG
if (nn != rhs.nn || mm!= rhs.mm)
laerror("Mat += Mat of incompatible matrices");
#endif
copyonwrite();
cblas_daxpy(nn*mm, 1.0, rhs, 1, *this, 1);
return *this;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat< complex<double> > &
NRMat< complex<double> >::operator+=(const NRMat< complex<double> > &rhs)
{
#ifdef DEBUG
if (nn != rhs.nn || mm!= rhs.mm)
laerror("Mat += Mat of incompatible matrices");
#endif
copyonwrite();
2009-11-12 22:01:19 +01:00
cblas_zaxpy(nn*mm, &CONE, rhs[0], 1, (*this)[0], 1);
2004-03-17 04:07:21 +01:00
return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise sum
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const NRMat<T> &rhs)
{
#ifdef DEBUG
if (nn != rhs.nn || mm!= rhs.mm)
	laerror("Mat += Mat of incompatible matrices");
#endif
copyonwrite();
//all nn*mm elements (nn*nn was wrong for non-square matrices)
const int total = nn*mm;
#ifdef MATPTR
for (int i=0; i< total; i++) v[0][i] += rhs.v[0][i] ;
#else
for (int i=0; i< total; i++) v[i] += rhs.v[i] ;
#endif
return *this;
}
2004-03-17 04:07:21 +01:00
// Mat -= Mat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat<double> & NRMat<double>::operator-=(const NRMat<double> &rhs)
{
#ifdef DEBUG
if (nn != rhs.nn || mm!= rhs.mm)
laerror("Mat -= Mat of incompatible matrices");
#endif
copyonwrite();
cblas_daxpy(nn*mm, -1.0, rhs, 1, *this, 1);
return *this;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat< complex<double> > &
NRMat< complex<double> >::operator-=(const NRMat< complex<double> > &rhs)
{
#ifdef DEBUG
if (nn != rhs.nn || mm!= rhs.mm)
laerror("Mat -= Mat of incompatible matrices");
#endif
copyonwrite();
2009-11-12 22:01:19 +01:00
cblas_zaxpy(nn*mm, &CMONE, rhs[0], 1, (*this)[0], 1);
2004-03-17 04:07:21 +01:00
return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: element-wise difference
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const NRMat<T> &rhs)
{
#ifdef DEBUG
if (nn != rhs.nn || mm!= rhs.mm)
	laerror("Mat -= Mat of incompatible matrices");
#endif
copyonwrite();
//all nn*mm elements (nn*nn was wrong for non-square matrices)
const int total = nn*mm;
#ifdef MATPTR
for (int i=0; i< total; i++) v[0][i] -= rhs.v[0][i] ;
#else
for (int i=0; i< total; i++) v[i] -= rhs.v[i] ;
#endif
return *this;
}
2004-03-17 04:07:21 +01:00
// Mat += SMat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat<double> & NRMat<double>::operator+=(const NRSMat<double> &rhs)
{
#ifdef DEBUG
if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat");
#endif
const double *p = rhs;
copyonwrite();
for (int i=0; i<nn; i++) {
cblas_daxpy(i+1, 1.0, p, 1, (*this)[i], 1);
p += i+1;
}
p = rhs; p++;
for (int i=1; i<nn; i++) {
cblas_daxpy(i, 1.0, p, 1, (*this)[0]+i, nn);
p += i+1;
}
return *this;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat< complex<double> > &
NRMat< complex<double> >::operator+=(const NRSMat< complex<double> > &rhs)
{
#ifdef DEBUG
if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat");
#endif
const complex<double> *p = rhs;
copyonwrite();
for (int i=0; i<nn; i++) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy(i+1, &CONE, p, 1, (*this)[i], 1);
2004-03-17 04:07:21 +01:00
p += i+1;
}
p = rhs; p++;
for (int i=1; i<nn; i++) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy(i, &CONE, p, 1, (*this)[0]+i, nn);
2004-03-17 04:07:21 +01:00
p += i+1;
}
return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 06:34:59 +01:00
//and for general type: adds the symmetric matrix stored as packed lower triangle (row i: (i,0..i))
template <typename T>
NRMat<T> & NRMat<T>::operator+=(const NRSMat<T> &rhs)
{
#ifdef DEBUG
if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat+=SMat");
#endif
const T *packed = rhs;
copyonwrite();
const T *src = packed;
for (int r=0; r<nn; ++r)
	{
	T *dst = (*this)[r];
	for(int c=0; c<=r; ++c) dst[c] += src[c];
	src += r+1;
	}
src = packed+1;
for (int c=1; c<nn; ++c)
	{
	for(int r=0; r<c; ++r) *((*this)[r]+c) += src[r];
	src += c+1;
	}
return *this;
}
2004-03-17 04:07:21 +01:00
// Mat -= SMat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat<double> & NRMat<double>::operator-=(const NRSMat<double> &rhs)
{
#ifdef DEBUG
if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat");
#endif
const double *p = rhs;
copyonwrite();
for (int i=0; i<nn; i++) {
cblas_daxpy(i+1, -1.0, p, 1, (*this)[i], 1);
p += i+1;
}
p = rhs; p++;
for (int i=1; i<nn; i++) {
cblas_daxpy(i, -1.0, p, 1, (*this)[0]+i, nn);
p += i+1;
}
return *this;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat< complex<double> > &
NRMat< complex<double> >::operator-=(const NRSMat< complex<double> > &rhs)
{
#ifdef DEBUG
if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat");
#endif
const complex<double> *p = rhs;
copyonwrite();
for (int i=0; i<nn; i++) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy(i+1, &CMONE, p, 1, (*this)[i], 1);
2004-03-17 04:07:21 +01:00
p += i+1;
}
p = rhs; p++;
for (int i=1; i<nn; i++) {
2009-11-12 22:01:19 +01:00
cblas_zaxpy(i, &CMONE, p, 1, (*this)[0]+i, nn);
2004-03-17 04:07:21 +01:00
p += i+1;
}
return *this;
}
2004-03-17 06:34:59 +01:00
//and for general type: subtracts the symmetric matrix stored as packed lower triangle
//(row i holds (i,0..i)); each element is applied at (i,j) and its mirror (j,i)
template <typename T>
NRMat<T> & NRMat<T>::operator-=(const NRSMat<T> &rhs)
{
#ifdef DEBUG
//message previously said "Mat+=SMat" although this is the subtraction
if (nn!=mm || nn!=rhs.nrows()) laerror("incompatible matrix size in Mat-=SMat");
#endif
const T *p = rhs;
copyonwrite();
for (int i=0; i<nn; i++) {
	for(int j=0; j<i+1; ++j) *((*this)[i]+j) -= p[j];
	p += i+1;
}
p = rhs; p++;
for (int i=1; i<nn; i++) {
	for(int j=0; j<i; ++j) *((*this)[j]+i) -= p[j];
	p += i+1;
}
return *this;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Mat.Mat - scalar product
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const double NRMat<double>::dot(const NRMat<double> &rhs) const
{
#ifdef DEBUG
if(nn!=rhs.nn || mm!= rhs.mm) laerror("Mat.Mat incompatible matrices");
#endif
return cblas_ddot(nn*mm, (*this)[0], 1, rhs[0], 1);
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const complex<double>
NRMat< complex<double> >::dot(const NRMat< complex<double> > &rhs) const
{
#ifdef DEBUG
if(nn!=rhs.nn || mm!= rhs.mm) laerror("Mat.Mat incompatible matrices");
#endif
complex<double> dot;
2009-11-12 22:01:19 +01:00
cblas_zdotc_sub(nn*mm, (*this)[0], 1, rhs[0], 1,
&dot);
2004-03-17 04:07:21 +01:00
return dot;
}
// Mat * Mat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat<double> NRMat<double>::operator*(const NRMat<double> &rhs) const
{
#ifdef DEBUG
if (mm != rhs.nn) laerror("product of incompatible matrices");
2006-09-19 17:59:49 +02:00
if (rhs.mm <=0) laerror("illegal matrix dimension in gemm");
2004-03-17 04:07:21 +01:00
#endif
NRMat<double> result(nn, rhs.mm);
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, nn, rhs.mm, mm, 1.0,
*this, mm, rhs, rhs.mm, 0.0, result, rhs.mm);
return result;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat< complex<double> >
NRMat< complex<double> >::operator*(const NRMat< complex<double> > &rhs) const
{
#ifdef DEBUG
if (mm != rhs.nn) laerror("product of incompatible matrices");
#endif
NRMat< complex<double> > result(nn, rhs.mm);
cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, nn, rhs.mm, mm,
2009-11-12 22:01:19 +01:00
&CONE,(*this)[0], mm, rhs[0],
rhs.mm, &CZERO, result[0], rhs.mm);
2004-03-17 04:07:21 +01:00
return result;
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Multiply by diagonal from L
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat<double>::diagmultl(const NRVec<double> &rhs)
{
#ifdef DEBUG
if (nn != rhs.size()) laerror("incompatible matrix dimension in diagmultl");
#endif
copyonwrite();
for(int i=0; i<nn; i++) cblas_dscal(mm, rhs[i], (*this)[i], 1);
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat< complex<double> >::diagmultl(const NRVec< complex<double> > &rhs)
{
#ifdef DEBUG
if (nn != rhs.size()) laerror("incompatible matrix dimension in diagmultl");
#endif
copyonwrite();
for (int i=0; i<nn; i++) cblas_zscal(mm, &rhs[i], (*this)[i], 1);
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Multiply by diagonal from R
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat<double>::diagmultr(const NRVec<double> &rhs)
{
#ifdef DEBUG
if (mm != rhs.size()) laerror("incompatible matrix dimension in diagmultr");
#endif
copyonwrite();
2006-10-21 17:32:53 +02:00
for (int i=0; i<mm; i++) cblas_dscal(nn, rhs[i], &(*this)(0,i), mm);
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat< complex<double> >::diagmultr(const NRVec< complex<double> > &rhs)
{
#ifdef DEBUG
if (mm != rhs.size()) laerror("incompatible matrix dimension in diagmultl");
#endif
copyonwrite();
2006-10-21 17:32:53 +02:00
for (int i=0; i<mm; i++) cblas_zscal(nn, &rhs[i], &(*this)(0,i), mm);
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// Mat * Smat, decomposed to nn x Vec * Smat
2006-08-16 23:43:45 +02:00
//NOTE: dsymm is not appropriate as it works on UNPACKED symmetric matrix
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat<double>
NRMat<double>::operator*(const NRSMat<double> &rhs) const
{
#ifdef DEBUG
if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat");
#endif
NRMat<double> result(nn, rhs.ncols());
for (int i=0; i<nn; i++)
cblas_dspmv(CblasRowMajor, CblasLower, mm, 1.0, &rhs[0],
(*this)[i], 1, 0.0, result[i], 1);
return result;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat< complex<double> >
NRMat< complex<double> >::operator*(const NRSMat< complex<double> > &rhs) const
{
#ifdef DEBUG
if (mm != rhs.nrows()) laerror("incompatible dimension in Mat*SMat");
#endif
NRMat< complex<double> > result(nn, rhs.ncols());
for (int i=0; i<nn; i++)
2009-11-12 22:01:19 +01:00
cblas_zhpmv(CblasRowMajor, CblasLower, mm, &CONE, &rhs[0],
(*this)[i], 1, &CZERO, result[i], 1);
2004-03-17 04:07:21 +01:00
return result;
}
2006-08-15 22:10:08 +02:00
2004-03-17 04:07:21 +01:00
// complex conjugate of Mat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat<double> &NRMat<double>::conjugateme() {return *this;}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRMat< complex<double> > & NRMat< complex<double> >::conjugateme()
{
copyonwrite();
cblas_dscal(mm*nn, -1.0, (double *)((*this)[0])+1, 2);
return *this;
}
// transpose and optionally conjugate
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat<double> NRMat<double>::transpose(bool conj) const
{
NRMat<double> result(mm,nn);
for(int i=0; i<nn; i++) cblas_dcopy(mm, (*this)[i], 1, result[0]+i, nn);
return result;
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat< complex<double> >
NRMat< complex<double> >::transpose(bool conj) const
{
NRMat< complex<double> > result(mm,nn);
for (int i=0; i<nn; i++)
2009-11-12 22:01:19 +01:00
cblas_zcopy(mm, (*this)[i], 1, (result[0]+i), nn);
2004-03-17 04:07:21 +01:00
if (conj) cblas_dscal(mm*nn, -1.0, (double *)(result[0])+1, 2);
return result;
}
// gemm : this = alpha*op( A )*op( B ) + beta*this
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat<double>::gemm(const double &beta, const NRMat<double> &a,
const char transa, const NRMat<double> &b, const char transb,
const double &alpha)
{
int k(transa=='n'?a.mm:a.nn);
#ifdef DEBUG
2006-04-06 23:45:51 +02:00
int l(transa=='n'?a.nn:a.mm);
int kk(transb=='n'?b.nn:b.mm);
int ll(transb=='n'?b.mm:b.nn);
2004-03-17 04:07:21 +01:00
if (l!=nn || ll!=mm || k!=kk) laerror("incompatible matrices in Mat:gemm()");
2006-09-19 17:59:49 +02:00
if(b.mm <=0 || mm<=0) laerror("illegal matrix dimension in gemm");
2004-03-17 04:07:21 +01:00
#endif
if (alpha==0.0 && beta==1.0) return;
copyonwrite();
cblas_dgemm(CblasRowMajor, (transa=='n' ? CblasNoTrans : CblasTrans),
(transb=='n' ? CblasNoTrans : CblasTrans), nn, mm, k, alpha, a,
a.mm, b , b.mm, beta, *this , mm);
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat< complex<double> >::gemm(const complex<double> & beta,
const NRMat< complex<double> > & a, const char transa,
const NRMat< complex<double> > & b, const char transb,
const complex<double> & alpha)
{
int k(transa=='n'?a.mm:a.nn);
#ifdef DEBUG
2006-04-06 23:45:51 +02:00
int l(transa=='n'?a.nn:a.mm);
int kk(transb=='n'?b.nn:b.mm);
int ll(transb=='n'?b.mm:b.nn);
2004-03-17 04:07:21 +01:00
if (l!=nn || ll!=mm || k!=kk) laerror("incompatible matrices in Mat:gemm()");
#endif
if (alpha==CZERO && beta==CONE) return;
copyonwrite();
cblas_zgemm(CblasRowMajor,
(transa=='n' ? CblasNoTrans : (transa=='c'?CblasConjTrans:CblasTrans)),
(transb=='n' ? CblasNoTrans : (transa=='c'?CblasConjTrans:CblasTrans)),
nn, mm, k, &alpha, a , a.mm, b , b.mm, &beta, *this , mm);
}
// norm of Mat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const double NRMat<double>::norm(const double scalar) const
{
if (!scalar) return cblas_dnrm2(nn*mm, (*this)[0], 1);
double sum = 0;
for (int i=0; i<nn; i++)
for (int j=0; j<mm; j++) {
register double tmp;
#ifdef MATPTR
tmp = v[i][j];
#else
tmp = v[i*mm+j];
#endif
if (i==j) tmp -= scalar;
sum += tmp*tmp;
}
2009-11-12 22:01:19 +01:00
return std::sqrt(sum);
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
const double NRMat< complex<double> >::norm(const complex<double> scalar) const
{
if (scalar == CZERO) return cblas_dznrm2(nn*mm, (*this)[0], 1);
double sum = 0;
for (int i=0; i<nn; i++)
for (int j=0; j<mm; j++) {
register complex<double> tmp;
#ifdef MATPTR
tmp = v[i][j];
#else
tmp = v[i*mm+j];
#endif
if (i==j) tmp -= scalar;
sum += tmp.real()*tmp.real()+tmp.imag()*tmp.imag();
}
2009-11-12 22:01:19 +01:00
return std::sqrt(sum);
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// axpy: this = a * Mat
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat<double>::axpy(const double alpha, const NRMat<double> &mat)
{
#ifdef DEBUG
if (nn!=mat.nn || mm!=mat.mm) laerror("daxpy of incompatible matrices");
#endif
copyonwrite();
cblas_daxpy(nn*mm, alpha, mat, 1, *this, 1);
}
2005-12-08 13:06:23 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRMat< complex<double> >::axpy(const complex<double> alpha,
const NRMat< complex<double> > & mat)
{
#ifdef DEBUG
if (nn!=mat.nn || mm!=mat.mm) laerror("zaxpy of incompatible matrices");
#endif
copyonwrite();
2009-11-12 22:01:19 +01:00
cblas_zaxpy(nn*mm, &alpha, mat, 1, (*this)[0], 1);
2004-03-17 04:07:21 +01:00
}
2005-12-08 13:06:23 +01:00
2004-03-17 04:07:21 +01:00
// trace of Mat
2008-04-16 14:56:02 +02:00
template <typename T>
const T NRMat<T>::trace() const
2004-03-17 04:07:21 +01:00
{
#ifdef DEBUG
if (nn != mm) laerror("no-square matrix in Mat::trace()");
#endif
2008-04-16 14:56:02 +02:00
T sum=0;
2004-03-17 04:07:21 +01:00
#ifdef MATPTR
2008-04-16 14:56:02 +02:00
for (int i=0; i<nn; ++i) sum += v[i][i];
2004-03-17 04:07:21 +01:00
#else
2008-04-16 14:56:02 +02:00
for (int i=0; i<nn*nn; i+=(nn+1)) sum += v[i];
2004-03-17 04:07:21 +01:00
#endif
2008-04-16 14:56:02 +02:00
return sum;
2004-03-17 04:07:21 +01:00
}
2005-02-02 15:49:33 +01:00
2008-04-16 14:56:02 +02:00
2005-02-02 15:49:33 +01:00
//get diagonal; for compatibility with large matrices do not return newly created object
//for non-square get diagonal of A^TA, will be used as preconditioner
2005-12-08 13:06:23 +01:00
template<>
2006-04-06 23:45:51 +02:00
const double * NRMat<double>::diagonalof(NRVec<double> &r, const bool divide, bool cache) const
2005-02-02 15:49:33 +01:00
{
if (r.size() != nn) laerror("diagonalof() incompatible vector");
2005-02-04 15:31:42 +01:00
double a;
r.copyonwrite();
2005-02-02 15:49:33 +01:00
if(nn==mm)
{
#ifdef MATPTR
2005-02-04 15:31:42 +01:00
if(divide) for (int i=0; i< nn; i++) if((a=v[i][i])) r[i]/=a;
else for (int i=0; i< nn; i++) r[i] = v[i][i];
2005-02-02 15:49:33 +01:00
#else
2005-02-04 15:31:42 +01:00
if(divide) {int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) if((a=v[i])) r[j] /=a;}
else {int i,j; for (i=j=0; j< nn; ++j, i+=nn+1) r[j] = v[i];}
2005-02-02 15:49:33 +01:00
#endif
}
else //non-square
{
for (int i=0; i< mm; i++)
2005-02-04 15:31:42 +01:00
{
2005-02-02 15:49:33 +01:00
#ifdef MATPTR
2005-02-04 15:31:42 +01:00
a= cblas_ddot(nn,v[0]+i,mm,v[0]+i,mm);
2005-02-02 15:49:33 +01:00
#else
2005-02-04 15:31:42 +01:00
a=cblas_ddot(nn,v+i,mm,v+i,mm);
2005-02-02 15:49:33 +01:00
#endif
2005-02-04 15:31:42 +01:00
if(divide) {if(a) r[i]/=a;}
else r[i] = a;
}
2005-02-02 15:49:33 +01:00
}
2006-04-06 23:45:51 +02:00
return divide?NULL:&r[0];
2005-02-02 15:49:33 +01:00
}
2008-03-01 17:55:18 +01:00
//set diagonal: copies r into the diagonal of a square matrix, other elements untouched
template<>
void NRMat<double>::diagonalset(const NRVec<double> &r)
{
if (r.size() != nn) laerror("diagonalset() incompatible vector");
if(nn!=mm) laerror("diagonalset only for square matrix");
copyonwrite();
for (int k=0; k<nn; ++k)
	{
#ifdef MATPTR
	v[k][k] = r[k];
#else
	v[k*(nn+1)] = r[k];
#endif
	}
}
2009-10-19 21:38:57 +02:00
//orthonormalize: modified Gram-Schmidt on the rows (rowcol==true) or the columns (rowcol==false)
//of the matrix, in place. With metric, the inner product is <x,y> = x^T * (*metric) * y
//(metric must be nrows==mm for rows, nn for columns); otherwise the Euclidean one.
//Calls laerror on a zero vector or a non-positive metric norm.
template<>
void NRMat<double>::orthonormalize(const bool rowcol, const NRSMat<double> *metric) //modified Gram-Schmidt
{
//the storage is written below; like every other mutating method, detach shared data first
copyonwrite();
if(metric) //general metric
	{
	if(rowcol) //vectors are rows
		{
		if((*metric).nrows() != mm) laerror("incompatible metric in orthonormalize");
		for(int j=0; j<nn; ++j)
			{
			//remove components along the already orthonormalized rows 0..j-1
			for(int i=0; i<j; ++i)
				{
				NRVec<double> tmp = *metric * (*this).row(i);
				double fact = cblas_ddot(mm,(*this)[j],1,tmp,1);
				cblas_daxpy(mm,-fact,(*this)[i],1,(*this)[j],1);
				}
			NRVec<double> tmp = *metric * (*this).row(j);
			double norm = cblas_ddot(mm,(*this)[j],1,tmp,1);
			if(norm<=0.) laerror("zero vector in orthonormalize or nonpositive metric");
			cblas_dscal(mm,1./std::sqrt(norm),(*this)[j],1);
			}
		}
	else //vectors are columns (stride mm)
		{
		if((*metric).nrows() != nn) laerror("incompatible metric in orthonormalize");
		for(int j=0; j<mm; ++j)
			{
			for(int i=0; i<j; ++i)
				{
				NRVec<double> tmp = *metric * (*this).column(i);
				double fact = cblas_ddot(nn,&(*this)[0][j],mm,tmp,1);
				cblas_daxpy(nn,-fact,&(*this)[0][i],mm,&(*this)[0][j],mm);
				}
			NRVec<double> tmp = *metric * (*this).column(j);
			double norm = cblas_ddot(nn,&(*this)[0][j],mm,tmp,1);
			if(norm<=0.) laerror("zero vector in orthonormalize or nonpositive metric");
			cblas_dscal(nn,1./std::sqrt(norm),&(*this)[0][j],mm);
			}
		}
	}
else //unit metric
	{
	if(rowcol) //vectors are rows
		{
		for(int j=0; j<nn; ++j)
			{
			for(int i=0; i<j; ++i)
				{
				double fact = cblas_ddot(mm,(*this)[j],1,(*this)[i],1);
				cblas_daxpy(mm,-fact,(*this)[i],1,(*this)[j],1);
				}
			double norm = cblas_dnrm2(mm,(*this)[j],1);
			if(norm==0.) laerror("zero vector in orthonormalize");
			cblas_dscal(mm,1./norm,(*this)[j],1);
			}
		}
	else //vectors are columns (stride mm)
		{
		for(int j=0; j<mm; ++j)
			{
			for(int i=0; i<j; ++i)
				{
				double fact = cblas_ddot(nn,&(*this)[0][j],mm,&(*this)[0][i],mm);
				cblas_daxpy(nn,-fact,&(*this)[0][i],mm,&(*this)[0][j],mm);
				}
			double norm = cblas_dnrm2(nn,&(*this)[0][j],mm);
			if(norm==0.) laerror("zero vector in orthonormalize");
			cblas_dscal(nn,1./norm,&(*this)[0][j],mm);
			}
		}
	}
}
2005-02-02 15:49:33 +01:00
2004-03-17 04:07:21 +01:00
2006-09-10 22:06:44 +02:00
//////////////////////////////////////////////////////////////////////////////
//// forced instantiation of the NRMat templates in this object file
// floating-point element types
template class NRMat<double>;
template class NRMat<complex<double> >;
// plain/signed integer element types
template class NRMat<char>;
template class NRMat<short>;
template class NRMat<int>;
template class NRMat<long>;
template class NRMat<long long>;
// unsigned integer element types
template class NRMat<unsigned char>;
template class NRMat<unsigned short>;
template class NRMat<unsigned int>;
template class NRMat<unsigned long>;
template class NRMat<unsigned long long>;
2006-09-10 22:06:44 +02:00
2009-11-12 22:01:19 +01:00
}//namespace