LA_library/mat.h
2004-03-17 03:07:21 +00:00

347 lines
8.6 KiB
C++

#ifndef _LA_MAT_H_
#define _LA_MAT_H_
#include "vec.h"
#include "smat.h"
// Dense matrix of T with nn rows and mm columns.
// Objects created by the copy constructor share the element buffer `v` and the
// counter `count` with their source instead of duplicating the data.
// Storage layout depends on MATPTR: when it is defined, `v` is an array of row
// pointers into one contiguous buffer; otherwise `v` is the contiguous buffer
// itself and element (i,j) is v[i*mm+j].
template <typename T>
class NRMat {
protected:
int nn; // number of rows (returned by nrows())
int mm; // number of columns (returned by ncols())
#ifdef MATPTR
T **v; // row pointers; v[0] is the start of the contiguous element buffer
#else
T *v; // contiguous element buffer, row after row
#endif
int *count; // counter shared by all NRMat objects referring to the same buffer; 0 for a default-constructed matrix
public:
friend class NRVec<T>;
friend class NRSMat<T>;
// --- construction / destruction ---
inline NRMat() : nn(0), mm(0), v(0), count(0) {}; // empty matrix: no buffer, no counter
inline NRMat(const int n, const int m); // n x m, elements not explicitly filled
inline NRMat(const T &a, const int n, const int m); // n x m, every element set to a
NRMat(const T *a, const int n, const int m); // n x m, n*m elements copied from a
inline NRMat(const NRMat &rhs); // shares rhs's buffer and increments the shared counter
explicit NRMat(const NRSMat<T> &rhs); // square matrix filled symmetrically from rhs
#ifndef MATPTR
NRMat(const NRVec<T> &rhs, const int n, const int m); // n x m matrix sharing the buffer of rhs
#endif
~NRMat();
inline int getcount() const {return count?*count:0;} // value of the shared counter, or 0 when there is none
// --- assignment and compound assignment ---
NRMat & operator=(const NRMat &rhs); //assignment
NRMat & operator=(const T &a); //assign a to diagonal
NRMat & operator|=(const NRMat &rhs); //assignment to a new copy
NRMat & operator+=(const T &a); //add diagonal
NRMat & operator-=(const T &a); //subtract diagonal
NRMat & operator*=(const T &a); //multiply by a scalar
NRMat & operator+=(const NRMat &rhs);
NRMat & operator-=(const NRMat &rhs);
NRMat & operator+=(const NRSMat<T> &rhs);
NRMat & operator-=(const NRSMat<T> &rhs);
// --- arithmetic returning a new matrix ---
const NRMat operator-() const; //unary minus
inline const NRMat operator+(const T &a) const;
inline const NRMat operator-(const T &a) const;
inline const NRMat operator*(const T &a) const;
inline const NRMat operator+(const NRMat &rhs) const;
inline const NRMat operator-(const NRMat &rhs) const;
inline const NRMat operator+(const NRSMat<T> &rhs) const;
inline const NRMat operator-(const NRSMat<T> &rhs) const;
const T dot(const NRMat &rhs) const; // scalar product of Mat.Mat
const NRMat operator*(const NRMat &rhs) const; // Mat * Mat
void diagmultl(const NRVec<T> &rhs); //multiply by a diagonal matrix from L
void diagmultr(const NRVec<T> &rhs); //multiply by a diagonal matrix from R
const NRMat operator*(const NRSMat<T> &rhs) const; // Mat * Smat
const NRMat operator&(const NRMat &rhs) const; // direct sum
const NRMat operator|(const NRMat<T> &rhs) const; // direct product
const NRVec<T> operator*(const NRVec<T> &rhs) const; // Mat * Vec
const NRVec<T> rsum() const; //sum of rows
const NRVec<T> csum() const; //sum of columns
// --- element access ---
inline T* operator[](const int i); //subscripting: pointer to row i
inline const T* operator[](const int i) const;
inline T& operator()(const int i, const int j); // (i,j) subscripts
inline const T& operator()(const int i, const int j) const;
inline int nrows() const;
inline int ncols() const;
// NOTE(review): copyonwrite() and resize() are defined outside this header
// section; copyonwrite() presumably gives this object a private buffer when
// the counter is > 1 -- confirm against the implementation.
void copyonwrite();
void resize(const int n, const int m);
inline operator T*(); //get a pointer to the data
inline operator const T*() const;
// --- transposition / conjugation ---
NRMat & transposeme(); // square matrices only
NRMat & conjugateme(); // square matrices only
const NRMat transpose(bool conj=false) const;
const NRMat conjugate() const;
void gemm(const T &beta, const NRMat &a, const char transa, const NRMat &b,
const char transb, const T &alpha);//this = alpha*op( A )*op( B ) + beta*this
// Disabled earlier copy of the strassen()/s_cutoff() declarations that appear
// near the end of this class.
/*
void strassen(const T beta, const NRMat &a, const char transa, const NRMat &b,
const char transb, const T alpha);//this := alpha*op( A )*op( B ) + beta*this
void s_cutoff(const int,const int,const int,const int) const;
*/
// --- I/O and reductions (definitions not visible in this header section) ---
void fprintf(FILE *f, const char *format, const int modulo) const;
void fscanf(FILE *f, const char *format);
const double norm(const T scalar=(T)0) const;
void axpy(const T alpha, const NRMat &x); // this += a*x
inline const T amax() const; // specialized below for double and complex<double>
const T trace() const;
//members concerning sparse matrix
explicit NRMat(const SparseMat<T> &rhs); // dense from sparse
NRMat & operator+=(const SparseMat<T> &rhs);
NRMat & operator-=(const SparseMat<T> &rhs);
inline void simplify() {}; //just for compatibility with sparse ones
//Strassen's multiplication (better than n^3, analogous syntax to gemm)
void strassen(const T beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const T alpha);//this := alpha*op( A )*op( B ) + beta*this
void s_cutoff(const int,const int,const int,const int) const;
};
// ctors
// Allocate an n x m matrix with a fresh counter set to 1.
// The elements receive no explicit value here; they are whatever new[] leaves.
template <typename T>
NRMat<T>::NRMat(const int n, const int m) : nn(n), mm(m), count(new int(1))
{
#ifdef MATPTR
	// One contiguous buffer plus a table of pointers to the start of each row.
	v = new T*[n];
	v[0] = new T[m*n];
	for (int row = 1; row < n; ++row)
		v[row] = v[row-1] + m;
#else
	v = new T[m*n];
#endif
}
// Allocate an n x m matrix with a fresh counter set to 1 and set every element
// to a. A value comparing equal to (T)0 is written with memset instead of an
// element-by-element loop.
template <typename T>
NRMat<T>::NRMat(const T &a, const int n, const int m) : nn(n), mm(m), count(new int)
{
	*count = 1;
#ifdef MATPTR
	// One contiguous buffer plus a table of pointers to the start of each row.
	v = new T*[n];
	v[0] = new T[m*n];
	for (int row = 1; row < n; row++) v[row] = v[row-1] + m;
	T *dst = v[0];
#else
	v = new T[m*n];
	T *dst = v;
#endif
	if (a != (T)0) {
		for (int k = 0; k < n*m; k++) dst[k] = a;
	} else {
		memset(dst, 0, n*m*sizeof(T));
	}
}
// Allocate an n x m matrix with a fresh counter set to 1 and fill it by copying
// n*m consecutive elements from a with memcpy.
template <typename T>
NRMat<T>::NRMat(const T *a, const int n, const int m) : nn(n), mm(m), count(new int)
{
	*count = 1;
#ifdef MATPTR
	// One contiguous buffer plus a table of pointers to the start of each row.
	v = new T*[n];
	v[0] = new T[m*n];
	for (int row = 1; row < n; row++) v[row] = v[row-1] + m;
	T *dst = v[0];
#else
	v = new T[m*n];
	T *dst = v;
#endif
	memcpy(dst, a, n*m*sizeof(T));
}
// Copy constructor: takes over rhs's dimensions, buffer pointer and counter
// pointer without copying elements, then increments the shared counter when
// one exists (rhs may be an empty matrix with count == 0).
template <typename T>
NRMat<T>::NRMat(const NRMat &rhs)
	: nn(rhs.nn), mm(rhs.mm), v(rhs.v), count(rhs.count)
{
	if (count != 0) {
		*count += 1;
	}
}
// Build a square matrix of order rhs.nrows() from an NRSMat.
// rhs is read through operator[] at consecutive indices 0,1,2,...; the value at
// each index is stored in both (row,col) and (col,row), with col running from 0
// to row inside each row.
template <typename T>
NRMat<T>::NRMat(const NRSMat<T> &rhs)
{
	nn = mm = rhs.nrows();
	count = new int;
	*count = 1;
#ifdef MATPTR
	// One contiguous buffer plus a table of pointers to the start of each row.
	v = new T*[nn];
	v[0] = new T[mm*nn];
	for (int r = 1; r < nn; r++) v[r] = v[r-1] + mm;
#else
	v = new T[mm*nn];
#endif
	int packed = 0; // running index into rhs
	for (int row = 0; row < nn; row++) {
		for (int col = 0; col <= row; col++, packed++) {
#ifdef MATPTR
			v[row][col] = v[col][row] = rhs[packed];
#else
			v[row*nn+col] = v[col*nn+row] = rhs[packed];
#endif
		}
	}
}
#ifndef MATPTR
// Construct an n x m matrix that refers to the same buffer and counter as the
// vector rhs; no elements are copied. Only available without MATPTR, where the
// matrix buffer is a plain T* like the vector's.
// With DEBUG defined, laerror() is called when n*m differs from rhs.nn.
template <typename T>
NRMat<T>::NRMat(const NRVec<T> &rhs, const int n, const int m)
{
#ifdef DEBUG
	if (n*m != rhs.nn) laerror("matrix dimensions incompatible with vector length");
#endif
	nn = n;
	mm = m;
	count = rhs.count;
	v = rhs.v;
	// Guard the increment exactly as the NRMat copy constructor does: a source
	// without a counter (presumably an empty vector -- confirm in vec.h) must
	// not lead to a null dereference.
	if (count) ++(*count);
}
#endif
// Mat + Smat
// Makes a copy of *this through the copy constructor, applies += rhs to the
// copy and returns it; *this itself is not assigned to here.
template <typename T>
inline const NRMat<T> NRMat<T>::operator+(const NRSMat<T> &rhs) const
{
	NRMat<T> sum(*this);
	sum += rhs;
	return sum;
}
// Mat - Smat
// Makes a copy of *this through the copy constructor, applies -= rhs to the
// copy and returns it; *this itself is not assigned to here.
template <typename T>
inline const NRMat<T> NRMat<T>::operator-(const NRSMat<T> &rhs) const
{
	NRMat<T> difference(*this);
	difference -= rhs;
	return difference;
}
// Mat[i] : pointer to the first element of i-th row (non-const access)
// With DEBUG defined, laerror() is called when the matrix has no buffer, when
// the shared counter is not 1, or when i is outside [0, nn).
template <typename T>
inline T* NRMat<T>::operator[](const int i)
{
#ifdef DEBUG
	// Test v before reading *count: a default-constructed matrix has both v and
	// count set to 0, so dereferencing count first would crash instead of
	// reporting the problem.
	if (!v) laerror("[] for unallocated Mat");
	else if (*count != 1) laerror("Mat lval use of [] with count > 1");
	if (i<0 || i>=nn) laerror("Mat [] out of range");
#endif
#ifdef MATPTR
	return v[i];
#else
	return v+i*mm;
#endif
}
// Read-only Mat[i]: pointer to the first element of row i.
// With DEBUG defined, laerror() is called for i outside [0, nn) and for a
// matrix without a buffer.
template <typename T>
inline const T* NRMat<T>::operator[](const int i) const
{
#ifdef DEBUG
	if (!(i>=0 && i<nn)) laerror("Mat [] out of range");
	if (v == 0) laerror("[] for unallocated Mat");
#endif
#ifdef MATPTR
	const T *row = v[i];
#else
	const T *row = v + i*mm;
#endif
	return row;
}
// Mat(i,j) reference to the matrix element M_{ij} (non-const access)
// With DEBUG defined, laerror() is called when the matrix has no buffer, when
// the shared counter is not 1, or when i is outside [0, nn) or j is outside
// [0, mm).
template <typename T>
inline T & NRMat<T>::operator()(const int i, const int j)
{
#ifdef DEBUG
	// Test v before reading *count: a default-constructed matrix has both v and
	// count set to 0.
	if (!v) laerror("(,) for unallocated Mat");
	else if (*count != 1) laerror("Mat lval use of (,) with count > 1");
	// Valid columns are 0..mm-1, so j == mm must be rejected as well.
	if (i<0 || i>=nn || j<0 || j>=mm) laerror("Mat (,) out of range");
#endif
#ifdef MATPTR
	return v[i][j];
#else
	return v[i*mm+j];
#endif
}
// Read-only Mat(i,j): const reference to the matrix element M_{ij}.
// With DEBUG defined, laerror() is called when i is outside [0, nn) or j is
// outside [0, mm), and when the matrix has no buffer.
template <typename T>
inline const T & NRMat<T>::operator()(const int i, const int j) const
{
#ifdef DEBUG
	// Valid columns are 0..mm-1, so j == mm must be rejected as well.
	if (i<0 || i>=nn || j<0 || j>=mm) laerror("Mat (,) out of range");
	if (!v) laerror("(,) for unallocated Mat");
#endif
#ifdef MATPTR
	return v[i][j];
#else
	return v[i*mm+j];
#endif
}
// number of rows (the stored nn)
template <typename T>
inline int NRMat<T>::nrows() const
{
	return this->nn;
}
// number of columns (the stored mm)
template <typename T>
inline int NRMat<T>::ncols() const
{
	return this->mm;
}
// Conversion to T*: pointer to the first element of the contiguous buffer.
// With DEBUG defined, laerror() is called for a matrix without a buffer.
template <typename T>
inline NRMat<T>::operator T* ()
{
#ifdef DEBUG
	if (v == 0) laerror("unallocated Mat in operator T*");
#endif
#ifdef MATPTR
	// The first row pointer is the start of the element buffer.
	return *v;
#else
	return v;
#endif
}
// Conversion to const T*: read-only pointer to the first element of the
// contiguous buffer. With DEBUG defined, laerror() is called for a matrix
// without a buffer.
template <typename T>
inline NRMat<T>::operator const T* () const
{
#ifdef DEBUG
	if (v == 0) laerror("unallocated Mat in operator T*");
#endif
#ifdef MATPTR
	// The first row pointer is the start of the element buffer.
	return *v;
#else
	return v;
#endif
}
// max element of Mat (specialization for double)
// Returns the element at the index reported by cblas_idamax over all nn*mm
// entries, i.e. the entry that BLAS selects as largest in absolute value; the
// stored (signed) value is returned, not its magnitude.
// `template <>` is required by standard C++ for an explicit specialization of
// a member of a class template.
template <>
inline const double NRMat<double>::amax() const
{
#ifdef MATPTR
	return v[0][cblas_idamax(nn*mm, v[0], 1)];
#else
	return v[cblas_idamax(nn*mm, v, 1)];
#endif
}
// max element of Mat (specialization for complex<double>)
// Returns the element at the index reported by cblas_izamax over all nn*mm
// entries; the buffer is passed to BLAS as an untyped pointer.
// `template <>` is required by standard C++ for an explicit specialization of
// a member of a class template.
template <>
inline const complex<double> NRMat< complex<double> >::amax() const
{
#ifdef MATPTR
	return v[0][cblas_izamax(nn*mm, (void *)v[0], 1)];
#else
	return v[cblas_izamax(nn*mm, (void *)v, 1)];
#endif
}
// I/O
// Stream insertion and extraction for NRMat. Only declared here; function
// declarations have external linkage by default, so no `extern` is needed.
template <typename T>
ostream& operator<<(ostream &s, const NRMat<T> &x);
template <typename T>
istream& operator>>(istream &s, NRMat<T> &x);
// generate operators: Mat + a, a + Mat, Mat * a
// NOTE(review): NRVECMAT_OPER and NRVECMAT_OPER2 are not defined in this file
// (presumably they come from vec.h -- confirm). Each invocation passes the
// class suffix `Mat` and one operator symbol; what exactly is generated must be
// read from the macro definitions.
NRVECMAT_OPER(Mat,+)
NRVECMAT_OPER(Mat,-)
NRVECMAT_OPER(Mat,*)
// generate Mat + Mat, Mat - Mat
NRVECMAT_OPER2(Mat,+)
NRVECMAT_OPER2(Mat,-)
#endif /* _LA_MAT_H_ */