LA_library/mat.h

556 lines
14 KiB
C
Raw Normal View History

2004-03-17 04:07:21 +01:00
#ifndef _LA_MAT_H_
#define _LA_MAT_H_
2005-02-14 01:10:07 +01:00
#include "la_traits.h"
2004-03-17 04:07:21 +01:00
template <typename T>
class NRMat {
protected:
int nn;
int mm;
#ifdef MATPTR
T **v;
#else
T *v;
#endif
int *count;
public:
friend class NRVec<T>;
friend class NRSMat<T>;
inline NRMat() : nn(0), mm(0), v(0), count(0) {};
inline NRMat(const int n, const int m);
inline NRMat(const T &a, const int n, const int m);
NRMat(const T *a, const int n, const int m);
inline NRMat(const NRMat &rhs);
explicit NRMat(const NRSMat<T> &rhs);
2005-11-20 14:46:00 +01:00
#ifdef MATPTR
explicit NRMat(const NRVec<T> &rhs, const int n, const int m) :NRMat(&rhs[0][0],n,m) {};
#else
explicit NRMat(const NRVec<T> &rhs, const int n, const int m);
2004-03-17 04:07:21 +01:00
#endif
~NRMat();
2005-02-14 01:10:07 +01:00
#ifdef MATPTR
2005-09-06 17:55:07 +02:00
const bool operator!=(const NRMat &rhs) const {if(nn!=rhs.nn || mm!=rhs.mm) return 1; return LA_traits<T>::gencmp(v[0],rhs.v[0],nn*mm);} //memcmp for scalars else elementwise
2005-02-14 01:10:07 +01:00
#else
2005-09-06 17:55:07 +02:00
const bool operator!=(const NRMat &rhs) const {if(nn!=rhs.nn || mm!=rhs.mm) return 1; return LA_traits<T>::gencmp(v,rhs.v,nn*mm);} //memcmp for scalars else elementwise
2005-02-14 01:10:07 +01:00
#endif
const bool operator==(const NRMat &rhs) const {return !(*this != rhs);};
2004-03-17 04:07:21 +01:00
inline int getcount() const {return count?*count:0;}
NRMat & operator=(const NRMat &rhs); //assignment
NRMat & operator=(const T &a); //assign a to diagonal
NRMat & operator|=(const NRMat &rhs); //assignment to a new copy
NRMat & operator+=(const T &a); //add diagonal
NRMat & operator-=(const T &a); //substract diagonal
NRMat & operator*=(const T &a); //multiply by a scalar
NRMat & operator+=(const NRMat &rhs);
NRMat & operator-=(const NRMat &rhs);
NRMat & operator+=(const NRSMat<T> &rhs);
NRMat & operator-=(const NRSMat<T> &rhs);
const NRMat operator-() const; //unary minus
inline const NRMat operator+(const T &a) const;
inline const NRMat operator-(const T &a) const;
inline const NRMat operator*(const T &a) const;
inline const NRMat operator+(const NRMat &rhs) const;
inline const NRMat operator-(const NRMat &rhs) const;
inline const NRMat operator+(const NRSMat<T> &rhs) const;
inline const NRMat operator-(const NRSMat<T> &rhs) const;
2005-02-18 23:08:15 +01:00
const T dot(const NRMat &rhs) const; // scalar product of Mat.Mat//@@@for complex do conjugate
2004-03-17 04:07:21 +01:00
const NRMat operator*(const NRMat &rhs) const; // Mat * Mat
2005-02-14 01:10:07 +01:00
const NRMat oplus(const NRMat &rhs) const; //direct sum
const NRMat otimes(const NRMat &rhs) const; //direct product
2004-03-17 04:07:21 +01:00
void diagmultl(const NRVec<T> &rhs); //multiply by a diagonal matrix from L
void diagmultr(const NRVec<T> &rhs); //multiply by a diagonal matrix from R
const NRMat operator*(const NRSMat<T> &rhs) const; // Mat * Smat
const NRMat operator&(const NRMat &rhs) const; // direct sum
const NRMat operator|(const NRMat<T> &rhs) const; // direct product
2005-02-18 23:08:15 +01:00
const NRVec<T> operator*(const NRVec<T> &rhs) const {NRVec<T> result(nn); result.gemv((T)0,*this,'n',(T)1,rhs); return result;}; // Mat * Vec
2004-03-17 04:07:21 +01:00
const NRVec<T> rsum() const; //sum of rows
const NRVec<T> csum() const; //sum of columns
2005-02-04 15:31:42 +01:00
void diagonalof(NRVec<T> &, const bool divide=0) const; //get diagonal
2004-03-17 04:07:21 +01:00
inline T* operator[](const int i); //subscripting: pointer to row i
inline const T* operator[](const int i) const;
inline T& operator()(const int i, const int j); // (i,j) subscripts
inline const T& operator()(const int i, const int j) const;
inline int nrows() const;
inline int ncols() const;
2005-02-18 23:08:15 +01:00
inline int size() const;
2005-09-11 22:04:24 +02:00
void get(int fd, bool dimensions=1, bool transposed=false);
void put(int fd, bool dimensions=1, bool transposed=false) const;
2004-03-17 04:07:21 +01:00
void copyonwrite();
void resize(const int n, const int m);
inline operator T*(); //get a pointer to the data
inline operator const T*() const;
2005-02-18 23:08:15 +01:00
NRMat & transposeme(int n=0); // square matrices only
2004-03-17 04:07:21 +01:00
NRMat & conjugateme(); // square matrices only
const NRMat transpose(bool conj=false) const;
const NRMat conjugate() const;
2005-09-11 22:04:24 +02:00
const NRMat submatrix(const int fromrow, const int torow, const int fromcol, const int tocol) const; //there is also independent less efficient routine for generally indexed submatrix
2004-03-17 04:07:21 +01:00
void gemm(const T &beta, const NRMat &a, const char transa, const NRMat &b,
const char transb, const T &alpha);//this = alpha*op( A )*op( B ) + beta*this
void fprintf(FILE *f, const char *format, const int modulo) const;
void fscanf(FILE *f, const char *format);
const double norm(const T scalar=(T)0) const;
void axpy(const T alpha, const NRMat &x); // this += a*x
inline const T amax() const;
const T trace() const;
//members concerning sparse matrix
explicit NRMat(const SparseMat<T> &rhs); // dense from sparse
NRMat & operator+=(const SparseMat<T> &rhs);
NRMat & operator-=(const SparseMat<T> &rhs);
2005-09-06 17:55:07 +02:00
void gemm(const T &beta, const SparseMat<T> &a, const char transa, const NRMat &b, const char transb, const T &alpha);//this = alpha*op( A )*op( B ) + beta*this
2004-03-17 04:07:21 +01:00
inline void simplify() {}; //just for compatibility with sparse ones
2005-02-18 23:08:15 +01:00
bool issymmetric() const {return 0;};
2005-10-12 13:14:55 +02:00
#ifndef NO_STRASSEN
2004-03-17 04:07:21 +01:00
//Strassen's multiplication (better than n^3, analogous syntax to gemm)
void strassen(const T beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const T alpha);//this := alpha*op( A )*op( B ) + beta*this
void s_cutoff(const int,const int,const int,const int) const;
2005-10-12 13:14:55 +02:00
#endif
2004-03-17 04:07:21 +01:00
};
2005-02-18 23:08:15 +01:00
//due to mutual includes this has to be after full class declaration
#include "vec.h"
#include "smat.h"
#include "sparsemat.h"
2004-03-17 04:07:21 +01:00
// ctors
// Allocate an n x m matrix owned solely by this object (reference count 1).
// Elements are default-initialized, i.e. left uninitialized for built-in T.
template <typename T>
NRMat<T>::NRMat(const int n, const int m) : nn(n), mm(m), count(new int(1))
{
#ifdef MATPTR
	v = new T*[n];
	v[0] = new T[m*n];
	for (int row = 1; row < n; ++row) v[row] = v[0] + row*m;
#else
	v = new T[m*n];
#endif
}
// Allocate an n x m matrix with every element equal to a.
// A zero fill value is written with memset (all-bits-zero), any other value
// element by element.
template <typename T>
NRMat<T>::NRMat(const T &a, const int n, const int m) : nn(n), mm(m), count(new int(1))
{
	T *data;
#ifdef MATPTR
	v = new T*[n];
	data = v[0] = new T[m*n];
	for (int row = 1; row < n; ++row) v[row] = v[row-1] + m;
#else
	data = v = new T[m*n];
#endif
	const int total = n*m;
	if (a != (T)0) {
		for (int k = 0; k < total; ++k) data[k] = a;
	} else {
		memset(data, 0, total*sizeof(T));
	}
}
// Allocate an n x m matrix and fill it with n*m contiguous row-major values
// read from a. The copy is byte-wise (memcpy); NOTE(review): this assumes T is
// trivially copyable.
template <typename T>
NRMat<T>::NRMat(const T *a, const int n, const int m) : nn(n), mm(m), count(new int(1))
{
#ifdef MATPTR
	v = new T*[n];
	v[0] = new T[m*n];
	for (int row = 1; row < n; ++row) v[row] = v[0] + row*m;
	T *dest = v[0];
#else
	v = new T[m*n];
	T *dest = v;
#endif
	memcpy(dest, a, n*m*sizeof(T));
}
// Shallow copy: share rhs's data buffer and increment the shared reference
// count (an unallocated rhs has no counter and is copied as empty).
template <typename T>
NRMat<T>::NRMat(const NRMat &rhs) : nn(rhs.nn), mm(rhs.mm), v(rhs.v), count(rhs.count)
{
	if (count != 0) ++*count;
}
// Expand a symmetric matrix into a full dense square matrix.
// Values are taken from rhs[0], rhs[1], ... in the order (0,0), (1,0), (1,1),
// (2,0), ... (row by row over the lower triangle) and mirrored across the diagonal.
template <typename T>
NRMat<T>::NRMat(const NRSMat<T> &rhs)
{
	nn = mm = rhs.nrows();
	count = new int(1);
#ifdef MATPTR
	v = new T*[nn];
	v[0] = new T[mm*nn];
	for (int row = 1; row < nn; ++row) v[row] = v[0] + row*mm;
#else
	v = new T[mm*nn];
#endif
	int packed = 0;
	for (int row = 0; row < nn; ++row) {
		for (int col = 0; col <= row; ++col) {
			const T value = rhs[packed++];
#ifdef MATPTR
			v[row][col] = value;
			v[col][row] = value;
#else
			v[row*nn+col] = value;
			v[col*nn+row] = value;
#endif
		}
	}
}
#ifndef MATPTR
// View the elements of vector rhs as an n x m row-major matrix.
// No data is copied: the matrix shares rhs's buffer and reference counter.
// An unallocated rhs (no counter) yields a matrix that also has no counter,
// instead of dereferencing a null pointer.
template <typename T>
NRMat<T>::NRMat(const NRVec<T> &rhs, const int n, const int m) : nn(n), mm(m), v(rhs.v), count(rhs.count)
{
#ifdef DEBUG
	if (n*m != rhs.nn) laerror("matrix dimensions incompatible with vector length");
#endif
	if (count) (*count)++;
}
#endif
// Mat + Smat: return the sum as a new dense matrix.
// The sum is accumulated with operator+= on a reference-sharing copy of *this.
// NOTE(review): this relies on operator+= (defined elsewhere) detaching the
// copy before writing, so that *this is unaffected — confirm.
template <typename T>
inline const NRMat<T> NRMat<T>::operator+(const NRSMat<T> &rhs) const
{
	NRMat<T> sum(*this);
	sum += rhs;
	return sum;
}
// Mat - Smat: return the difference as a new dense matrix.
// The difference is accumulated with operator-= on a reference-sharing copy of *this.
// NOTE(review): this relies on operator-= (defined elsewhere) detaching the
// copy before writing, so that *this is unaffected — confirm.
template <typename T>
inline const NRMat<T> NRMat<T>::operator-(const NRSMat<T> &rhs) const
{
	NRMat<T> difference(*this);
	difference -= rhs;
	return difference;
}
// Mat[i] : pointer to the first element of i-th row
// Writable access to row i: returns a pointer to the row's first element.
// DEBUG builds reject an unallocated matrix first (before touching the null
// counter), then a shared one (count > 1, where a write would affect every
// sharer), then an out-of-range row.
template <typename T>
inline T* NRMat<T>::operator[](const int i)
{
#ifdef DEBUG
	if (!count || !v) laerror("[] for unallocated Mat");
	if (*count != 1) laerror("Mat lval use of [] with count > 1");
	if (i<0 || i>=nn) laerror("Mat [] out of range");
#endif
#ifdef MATPTR
	return v[i];
#else
	return v+i*mm;
#endif
}
// Read-only access to row i: returns a pointer to the row's first element.
// Shared data is fine here because nothing is written through the pointer.
template <typename T>
inline const T* NRMat<T>::operator[](const int i) const
{
#ifdef DEBUG
	if (i<0 || i>=nn) laerror("Mat [] out of range");
	if (!v) laerror("[] for unallocated Mat");
#endif
#ifdef MATPTR
	const T *row = v[i];
#else
	const T *row = v + i*mm;
#endif
	return row;
}
// Mat(i,j) reference to the matrix element M_{ij}
// Writable reference to element (i,j), with 0 <= i < nrows() and 0 <= j < ncols().
// DEBUG builds reject an unallocated matrix first (before touching the null
// counter), then shared data (count > 1), then out-of-range indices. The
// column limit is exclusive, so j == ncols() is rejected as well.
template <typename T>
inline T & NRMat<T>::operator()(const int i, const int j)
{
#ifdef DEBUG
	if (!count || !v) laerror("(,) for unallocated Mat");
	if (*count != 1) laerror("Mat lval use of (,) with count > 1");
	if (i<0 || i>=nn || j<0 || j>=mm) laerror("Mat (,) out of range");
#endif
#ifdef MATPTR
	return v[i][j];
#else
	return v[i*mm+j];
#endif
}
// Read-only reference to element (i,j), with 0 <= i < nrows() and 0 <= j < ncols().
// The column limit is exclusive, so j == ncols() is rejected in DEBUG builds.
template <typename T>
inline const T & NRMat<T>::operator()(const int i, const int j) const
{
#ifdef DEBUG
	if (i<0 || i>=nn || j<0 || j>=mm) laerror("Mat (,) out of range");
	if (!v) laerror("(,) for unallocated Mat");
#endif
#ifdef MATPTR
	return v[i][j];
#else
	return v[i*mm+j];
#endif
}
// Number of rows (0 for an unallocated matrix).
template <typename T>
inline int NRMat<T>::nrows() const { return nn; }
// Number of columns (0 for an unallocated matrix).
template <typename T>
inline int NRMat<T>::ncols() const { return mm; }
2005-02-18 23:08:15 +01:00
// Total number of elements, nrows()*ncols().
template <typename T>
inline int NRMat<T>::size() const { return nn*mm; }
2004-03-17 04:07:21 +01:00
// Raw pointer to the first element of the contiguous row-major data block.
// Unlike the non-const operator[], this does not check whether the data is shared.
template <typename T>
inline NRMat<T>::operator T* ()
{
#ifdef DEBUG
	if (!v) laerror("unallocated Mat in operator T*");
#endif
#ifdef MATPTR
	T *data = v[0];
#else
	T *data = v;
#endif
	return data;
}
// Read-only raw pointer to the first element of the contiguous row-major data block.
template <typename T>
inline NRMat<T>::operator const T* () const
{
#ifdef DEBUG
	if (!v) laerror("unallocated Mat in operator T*");
#endif
#ifdef MATPTR
	const T *data = v[0];
#else
	const T *data = v;
#endif
	return data;
}
// max element of Mat
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
inline const double NRMat<double>::amax() const
{
#ifdef MATPTR
return v[0][cblas_idamax(nn*mm, v[0], 1)];
#else
return v[cblas_idamax(nn*mm, v, 1)];
#endif
}
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
inline const complex<double> NRMat< complex<double> >::amax() const
{
#ifdef MATPTR
return v[0][cblas_izamax(nn*mm, (void *)v[0], 1)];
#else
return v[cblas_izamax(nn*mm, (void *)v, 1)];
#endif
}
2004-03-17 17:39:07 +01:00
//basic stuff that must be available for any type T ... therefore defined here in the .h
// Destructor: drop this object's reference to the shared data. The last
// holder frees the data buffer (plus the row-pointer array under MATPTR)
// and the counter.
template <typename T>
NRMat<T>::~NRMat()
{
	if (count == 0) return;
	--*count;
	if (*count > 0) return;
	if (v != 0) {
#ifdef MATPTR
		delete[] v[0];
#endif
		delete[] v;
	}
	delete count;
}
// assign NRMat = NRMat
// Shallow assignment: release this object's current data (freeing it if this
// was the last holder), then share rhs's buffer and counter. Self-assignment
// is a no-op. Release mirrors the destructor: it is guarded against a null
// data pointer and frees when the count drops to zero or below.
template <typename T>
NRMat<T> & NRMat<T>::operator=(const NRMat<T> &rhs)
{
	if (this != &rhs) {
		if (count && --(*count) <= 0) {
			if (v) {
#ifdef MATPTR
				delete[] (v[0]);
#endif
				delete[] v;
			}
			delete count;
		}
		v = rhs.v;
		nn = rhs.nn;
		mm = rhs.mm;
		count = rhs.count;
		if (count) (*count)++;
	}
	return *this;
}
// Explicit deep copy of an NRMat: afterwards *this owns a private copy of
// rhs's data (reference count 1). The existing buffer is reused when *this
// is not shared and already has rhs's dimensions. Data is copied byte-wise
// with memcpy.
template <typename T>
NRMat<T> & NRMat<T>::operator|=(const NRMat<T> &rhs)
{
	if (this == &rhs) return *this;
#ifdef DEBUG
	if (!rhs.v) laerror("unallocated rhs in Mat operator |=");
#endif
	// detach from data shared with other matrices without touching it
	if (count && *count > 1) {
		--(*count);
		nn = mm = 0;
		count = 0;
		v = 0;
	}
	// discard our own buffer if its shape does not match
	const bool reshape = nn != rhs.nn || mm != rhs.mm;
	if (reshape) {
		if (v) {
#ifdef MATPTR
			delete[] (v[0]);
#endif
			delete[] (v);
			v = 0;
		}
		nn = rhs.nn;
		mm = rhs.mm;
	}
	if (!v) {
#ifdef MATPTR
		v = new T*[nn];
		v[0] = new T[mm*nn];
#else
		v = new T[mm*nn];
#endif
	}
#ifdef MATPTR
	for (int row = 1; row < nn; ++row) v[row] = v[row-1] + mm;
	memcpy(v[0], rhs.v[0], nn*mm*sizeof(T));
#else
	memcpy(v, rhs.v, nn*mm*sizeof(T));
#endif
	if (!count) count = new int;
	*count = 1;
	return *this;
}
// Detach this matrix from data shared with other matrices by giving it its
// own deep copy (new counter set to 1). Does nothing when the data is not shared.
template <typename T>
void NRMat<T>::copyonwrite()
{
#ifdef DEBUG
	if (!count) laerror("Mat::copyonwrite of undefined matrix");
#endif
	if (*count <= 1) return;
	(*count)--;
	count = new int(1);
	const int total = mm*nn;
#ifdef MATPTR
	T **rows = new T*[nn];
	rows[0] = new T[total];
	memcpy(rows[0], v[0], total*sizeof(T));
	for (int r = 1; r < nn; ++r) rows[r] = rows[r-1] + mm;
	v = rows;
#else
	T *copy = new T[total];
	memcpy(copy, v, total*sizeof(T));
	v = copy;
#endif
}
// Change the shape to n x m; n and m must both be zero or both be positive.
// New storage is allocated when the matrix is unallocated, shared, or changes
// shape; the contents are NOT preserved in those cases (new elements are
// default-initialized). An unshared matrix that keeps its shape is untouched.
// resize(0,0) releases the data and leaves an empty, unallocated matrix.
template <typename T>
void NRMat<T>::resize(const int n, const int m)
{
#ifdef DEBUG
	if (n<0 || m<0 || (n>0 && m==0) || (n==0 && m>0)) laerror("illegal dimensions in Mat::resize()");
#endif
	// Handled up front so that an unallocated matrix is not given a zero-sized
	// buffer (with MATPTR that would also write v[0] of a zero-length array).
	if (n==0 && m==0) {
		if (count && --(*count) <= 0) {
			if (v) {
#ifdef MATPTR
				delete[] (v[0]);
#endif
				delete[] v;
			}
			delete count;
		}
		count = 0;
		nn = mm = 0;
		v = 0;
		return;
	}
	// data shared with other matrices: leave it to them and start afresh
	if (count && *count > 1) {
		(*count)--;
		count = 0;
		v = 0;
		nn = 0;
		mm = 0;
	}
	// unallocated (or just detached): allocate fresh storage
	if (!count) {
		count = new int;
		*count = 1;
		nn = n;
		mm = m;
#ifdef MATPTR
		v = new T*[nn];
		v[0] = new T[m*n];
		for (int i=1; i<n; i++) v[i] = v[i-1] + m;
#else
		v = new T[m*n];
#endif
		return;
	}
	// here *count == 1: reallocate only if the shape changes
	if (n!=nn || m!=mm) {
		nn = n;
		mm = m;
#ifdef MATPTR
		delete[] (v[0]);
#endif
		delete[] v;
#ifdef MATPTR
		v = new T*[nn];
		v[0] = new T[m*n];
		for (int i=1; i<n; i++) v[i] = v[i-1] + m;
#else
		v = new T[m*n];
#endif
	}
}
2005-02-06 15:01:27 +01:00
2004-03-17 04:07:21 +01:00
// I/O
// Stream operators for NRMat: only declared here, the definitions live in
// another translation unit (NOTE(review): presumably with explicit
// instantiations for the supported T — confirm).
template <typename T> extern ostream& operator<<(ostream &s, const NRMat<T> &x);
template <typename T> extern istream& operator>>(istream &s, NRMat<T> &x);
// generate operators: Mat + a, a + Mat, Mat * a
// NOTE(review): NRVECMAT_OPER / NRVECMAT_OPER2 are defined outside this file.
// Presumably they supply the inline member operators declared in the class
// (operator+/-/* with a scalar T, operator+/- with another NRMat) — confirm
// the exact set of generated functions against the macro definitions.
NRVECMAT_OPER(Mat,+)
NRVECMAT_OPER(Mat,-)
NRVECMAT_OPER(Mat,*)
// generate Mat + Mat, Mat - Mat
NRVECMAT_OPER2(Mat,+)
NRVECMAT_OPER2(Mat,-)
#endif /* _LA_MAT_H_ */