*** empty log message ***

This commit is contained in:
jiri
2005-02-18 22:08:15 +00:00
parent 02a868e8aa
commit 6f42b9bb18
15 changed files with 195 additions and 208 deletions

61
vec.h
View File

@@ -1,24 +1,8 @@
#ifndef _LA_VEC_H_
#define _LA_VEC_H_
#include "laerror.h"
extern "C" {
#include "cblas.h"
}
#include <stdio.h>
#include <complex>
#include <string.h>
#include <iostream>
using namespace std;
#include "la_traits.h"
template <typename T> class NRVec;
template <typename T> class NRSMat;
template <typename T> class NRMat;
template <typename T> class SparseMat;
//////////////////////////////////////////////////////////////////////////////
// Forward declarations
template <typename T> void lawritemat(FILE *file,const T *a,int r,int c,
@@ -43,9 +27,6 @@ template<class T> \
inline const NR##E<T> NR##E<T>::operator X(const NR##E<T> &a) const \
{ return NR##E(*this) X##= a; }
#include "smat.h"
#include "mat.h"
// NRVec class
template <typename T>
@@ -84,9 +65,14 @@ public:
inline const NRVec operator+(const T &a) const;
inline const NRVec operator-(const T &a) const;
inline const NRVec operator*(const T &a) const;
inline const T operator*(const NRVec &rhs) const; //scalar product -> ddot
inline const NRVec operator*(const NRSMat<T> & S) const;
const NRVec operator*(const NRMat<T> &mat) const;
inline const T operator*(const NRVec &rhs) const; //scalar product -> dot
inline const T dot(const NRVec &rhs) const {return *this * rhs;}; //@@@for complex do conjugate
void gemv(const T beta, const NRMat<T> &a, const char trans, const T alpha, const NRVec &x);
void gemv(const T beta, const NRSMat<T> &a, const char trans /*just for compatibility*/, const T alpha, const NRVec &x);
void gemv(const T beta, const SparseMat<T> &a, const char trans, const T alpha, const NRVec &x);
const NRVec operator*(const NRMat<T> &mat) const {NRVec<T> result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;};
const NRVec operator*(const NRSMat<T> &mat) const {NRVec<T> result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;};
const NRVec operator*(const SparseMat<T> &mat) const {NRVec<T> result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;};
const NRMat<T> operator|(const NRVec<T> &rhs) const;
inline const T sum() const; //sum of its elements
inline const T dot(const T *a, const int stride=1) const; // ddot with a stride-vector
@@ -98,8 +84,6 @@ public:
~NRVec();
void axpy(const T alpha, const NRVec &x); // this+= a*x
void axpy(const T alpha, const T *x, const int stride=1); // this+= a*x
void gemv(const T beta, const NRMat<T> &a, const char trans,
const T alpha, const NRVec &x);
void copyonwrite();
void resize(const int n);
void get(int fd, bool dimensions=1);
@@ -112,11 +96,14 @@ public:
void fscanf(FILE *f, const char *format);
//sparse matrix concerning members
explicit NRVec(const SparseMat<T> &rhs); // dense from sparse matrix with one of dimensions =1
const NRVec operator*(const SparseMat<T> &mat) const; //vector*matrix
inline void simplify() {}; //just for compatibility with sparse ones
void gemv(const T beta, const SparseMat<T> &a, const char trans, const T alpha, const NRVec &x);
};
//due to mutual includes this has to be after full class declaration
#include "mat.h"
#include "smat.h"
#include "sparsemat.h"
template <typename T> ostream & operator<<(ostream &s, const NRVec<T> &x);
template <typename T> istream & operator>>(istream &s, NRVec<T> &x);
@@ -313,28 +300,36 @@ inline NRVec<T> & NRVec<T>::operator*=(const T &a)
inline const double NRVec<double>::operator*(const NRVec<double> &rhs) const
{
#ifdef DEBUG
if (nn != rhs.nn) laerror("ddot of incompatible vectors");
if (nn != rhs.nn) laerror("dot of incompatible vectors");
#endif
return cblas_ddot(nn, v, 1, rhs.v, 1);
return cblas_ddot(nn, v, 1, rhs.v, 1);
}
inline const complex<double>
NRVec< complex<double> >::operator*(const NRVec< complex<double> > &rhs) const
{
#ifdef DEBUG
if (nn != rhs.nn) laerror("ddot of incompatible vectors");
if (nn != rhs.nn) laerror("dot of incompatible vectors");
#endif
complex<double> dot;
cblas_zdotc_sub(nn, (void *)v, 1, (void *)rhs.v, 1, (void *)(&dot));
return dot;
}
// Vec * SMat = SMat * Vec
template <typename T>
inline const NRVec<T> NRVec<T>::operator*(const NRSMat<T> & S) const
template<typename T>
inline const T NRVec<T>::operator*(const NRVec<T> &rhs) const
{
return S * (*this);
#ifdef DEBUG
if (nn != rhs.nn) laerror("dot of incompatible vectors");
#endif
T dot = 0;
for(int i=0; i<nn; ++i) dot+= v[i]*rhs.v[i];
return dot;
}
// Sum of elements
inline const double NRVec<double>::sum() const
{