*** empty log message ***
This commit is contained in:
61
vec.h
61
vec.h
@@ -1,24 +1,8 @@
|
||||
#ifndef _LA_VEC_H_
|
||||
#define _LA_VEC_H_
|
||||
|
||||
#include "laerror.h"
|
||||
extern "C" {
|
||||
#include "cblas.h"
|
||||
}
|
||||
#include <stdio.h>
|
||||
#include <complex>
|
||||
#include <string.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
#include "la_traits.h"
|
||||
|
||||
template <typename T> class NRVec;
|
||||
template <typename T> class NRSMat;
|
||||
template <typename T> class NRMat;
|
||||
template <typename T> class SparseMat;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// Forward declarations
|
||||
template <typename T> void lawritemat(FILE *file,const T *a,int r,int c,
|
||||
@@ -43,9 +27,6 @@ template<class T> \
|
||||
inline const NR##E<T> NR##E<T>::operator X(const NR##E<T> &a) const \
|
||||
{ return NR##E(*this) X##= a; }
|
||||
|
||||
#include "smat.h"
|
||||
#include "mat.h"
|
||||
|
||||
|
||||
// NRVec class
|
||||
template <typename T>
|
||||
@@ -84,9 +65,14 @@ public:
|
||||
inline const NRVec operator+(const T &a) const;
|
||||
inline const NRVec operator-(const T &a) const;
|
||||
inline const NRVec operator*(const T &a) const;
|
||||
inline const T operator*(const NRVec &rhs) const; //scalar product -> ddot
|
||||
inline const NRVec operator*(const NRSMat<T> & S) const;
|
||||
const NRVec operator*(const NRMat<T> &mat) const;
|
||||
inline const T operator*(const NRVec &rhs) const; //scalar product -> dot
|
||||
inline const T dot(const NRVec &rhs) const {return *this * rhs;}; //@@@for complex do conjugate
|
||||
void gemv(const T beta, const NRMat<T> &a, const char trans, const T alpha, const NRVec &x);
|
||||
void gemv(const T beta, const NRSMat<T> &a, const char trans /*just for compatibility*/, const T alpha, const NRVec &x);
|
||||
void gemv(const T beta, const SparseMat<T> &a, const char trans, const T alpha, const NRVec &x);
|
||||
const NRVec operator*(const NRMat<T> &mat) const {NRVec<T> result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;};
|
||||
const NRVec operator*(const NRSMat<T> &mat) const {NRVec<T> result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;};
|
||||
const NRVec operator*(const SparseMat<T> &mat) const {NRVec<T> result(mat.ncols()); result.gemv((T)0,mat,'t',(T)1,*this); return result;};
|
||||
const NRMat<T> operator|(const NRVec<T> &rhs) const;
|
||||
inline const T sum() const; //sum of its elements
|
||||
inline const T dot(const T *a, const int stride=1) const; // ddot with a stride-vector
|
||||
@@ -98,8 +84,6 @@ public:
|
||||
~NRVec();
|
||||
void axpy(const T alpha, const NRVec &x); // this+= a*x
|
||||
void axpy(const T alpha, const T *x, const int stride=1); // this+= a*x
|
||||
void gemv(const T beta, const NRMat<T> &a, const char trans,
|
||||
const T alpha, const NRVec &x);
|
||||
void copyonwrite();
|
||||
void resize(const int n);
|
||||
void get(int fd, bool dimensions=1);
|
||||
@@ -112,11 +96,14 @@ public:
|
||||
void fscanf(FILE *f, const char *format);
|
||||
//sparse matrix concerning members
|
||||
explicit NRVec(const SparseMat<T> &rhs); // dense from sparse matrix with one of dimensions =1
|
||||
const NRVec operator*(const SparseMat<T> &mat) const; //vector*matrix
|
||||
inline void simplify() {}; //just for compatibility with sparse ones
|
||||
void gemv(const T beta, const SparseMat<T> &a, const char trans, const T alpha, const NRVec &x);
|
||||
};
|
||||
|
||||
//due to mutual includes this has to be after full class declaration
|
||||
#include "mat.h"
|
||||
#include "smat.h"
|
||||
#include "sparsemat.h"
|
||||
|
||||
template <typename T> ostream & operator<<(ostream &s, const NRVec<T> &x);
|
||||
template <typename T> istream & operator>>(istream &s, NRVec<T> &x);
|
||||
|
||||
@@ -313,28 +300,36 @@ inline NRVec<T> & NRVec<T>::operator*=(const T &a)
|
||||
inline const double NRVec<double>::operator*(const NRVec<double> &rhs) const
|
||||
{
|
||||
#ifdef DEBUG
|
||||
if (nn != rhs.nn) laerror("ddot of incompatible vectors");
|
||||
if (nn != rhs.nn) laerror("dot of incompatible vectors");
|
||||
#endif
|
||||
return cblas_ddot(nn, v, 1, rhs.v, 1);
|
||||
return cblas_ddot(nn, v, 1, rhs.v, 1);
|
||||
}
|
||||
|
||||
|
||||
inline const complex<double>
|
||||
NRVec< complex<double> >::operator*(const NRVec< complex<double> > &rhs) const
|
||||
{
|
||||
#ifdef DEBUG
|
||||
if (nn != rhs.nn) laerror("ddot of incompatible vectors");
|
||||
if (nn != rhs.nn) laerror("dot of incompatible vectors");
|
||||
#endif
|
||||
complex<double> dot;
|
||||
cblas_zdotc_sub(nn, (void *)v, 1, (void *)rhs.v, 1, (void *)(&dot));
|
||||
return dot;
|
||||
}
|
||||
|
||||
// Vec * SMat = SMat * Vec
|
||||
template <typename T>
|
||||
inline const NRVec<T> NRVec<T>::operator*(const NRSMat<T> & S) const
|
||||
template<typename T>
|
||||
inline const T NRVec<T>::operator*(const NRVec<T> &rhs) const
|
||||
{
|
||||
return S * (*this);
|
||||
#ifdef DEBUG
|
||||
if (nn != rhs.nn) laerror("dot of incompatible vectors");
|
||||
#endif
|
||||
T dot = 0;
|
||||
for(int i=0; i<nn; ++i) dot+= v[i]*rhs.v[i];
|
||||
return dot;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Sum of elements
|
||||
inline const double NRVec<double>::sum() const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user