LA_library/vec.cc

437 lines
9.9 KiB
C++
Raw Normal View History

2004-03-17 04:07:21 +01:00
#include <iostream>
#include "vec.h"
2005-02-14 01:10:07 +01:00
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
2005-11-20 14:46:00 +01:00
#include <errno.h>
2005-02-14 01:10:07 +01:00
extern "C" {
extern ssize_t read(int, void *, size_t);
extern ssize_t write(int, const void *, size_t);
}
2004-03-17 04:07:21 +01:00
//////////////////////////////////////////////////////////////////////////////
//// forced instantization in the corespoding object file
#define INSTANTIZE(T) \
template ostream & operator<<(ostream &s, const NRVec< T > &x); \
template istream & operator>>(istream &s, NRVec< T > &x); \
2005-09-11 22:04:24 +02:00
template void NRVec<T>::put(int fd, bool dim, bool transp) const; \
template void NRVec<T>::get(int fd, bool dim, bool transp); \
2005-02-18 23:08:15 +01:00
2004-03-17 04:07:21 +01:00
INSTANTIZE(double)
INSTANTIZE(complex<double>)
2004-03-17 06:34:59 +01:00
INSTANTIZE(int)
2005-02-14 01:10:07 +01:00
INSTANTIZE(unsigned int)
INSTANTIZE(short)
INSTANTIZE(unsigned short)
2004-03-17 06:34:59 +01:00
INSTANTIZE(char)
2005-02-14 01:10:07 +01:00
INSTANTIZE(unsigned char)
2004-03-17 04:07:21 +01:00
template NRVec<double>;
2005-02-14 01:10:07 +01:00
template NRVec<complex<double> >;
2004-03-17 06:34:59 +01:00
template NRVec<char>;
2005-09-05 00:12:09 +02:00
template NRVec<short>;
2005-02-18 23:08:15 +01:00
template NRVec<int>;
2004-03-17 04:07:21 +01:00
/*
* Templates first, specializations for BLAS next
*/
// conversion ctor: view an NRMat as a flat vector of nn*mm elements.
// No data is copied -- the vector takes over the matrix's buffer pointer and
// its reference counter and bumps the count.
#ifndef MATPTR
template <typename T>
NRVec<T>::NRVec(const NRMat<T> &rhs)
{
	nn = rhs.nn*rhs.mm;
	v = rhs.v;
	count = rhs.count;
	// NOTE(review): an empty/default-constructed matrix presumably carries no
	// counter (null pointer) -- confirm in NRMat. Only bump it when it exists
	// instead of dereferencing a null pointer.
	if (count) (*count)++;
}
#endif
2005-02-14 01:10:07 +01:00
// formatted I/O
2004-03-17 04:07:21 +01:00
// Text output of a vector: the dimension on its own line, followed by all
// elements on one line, separated by single blanks and ended by a newline.
template <typename T>
ostream & operator<<(ostream &s, const NRVec<T> &x)
{
	const int len = x.size();
	s << len << endl;
	for (int k = 0; k < len; ++k)
	{
		s << x[k];
		s << ((k + 1 < len) ? ' ' : '\n');
	}
	return s;
}
// Text input of a vector in the format written by operator<<: a dimension
// followed by that many elements.
// If the dimension cannot be read, x is left untouched and the stream stays
// in its failed state. Previously an uninitialized dimension was passed to
// resize(). A negative dimension sets failbit and also leaves x untouched.
template <typename T>
istream & operator>>(istream &s, NRVec<T> &x)
{
	int n = 0;
	if (!(s >> n)) return s;
	if (n < 0)
	{
		s.setstate(std::ios_base::failbit);
		return s;
	}
	x.resize(n);
	for (int i = 0; i < n; i++) s >> x[i];
	return s;
}
2005-02-14 01:10:07 +01:00
//raw I/O
template <typename T>
2005-09-11 22:04:24 +02:00
void NRVec<T>::put(int fd, bool dim, bool transp) const
2005-02-14 01:10:07 +01:00
{
errno=0;
int pad=1; //align at least 8-byte
if(dim)
{
if(sizeof(int) != write(fd,&nn,sizeof(int))) laerror("cannot write");
if(sizeof(int) != write(fd,&pad,sizeof(int))) laerror("cannot write");
}
LA_traits<T>::multiput(nn,fd,v,dim);
}
template <typename T>
2005-09-11 22:04:24 +02:00
void NRVec<T>::get(int fd, bool dim, bool transp)
2005-02-14 01:10:07 +01:00
{
int nn0[2]; //align at least 8-byte
errno=0;
if(dim)
{
if(2*sizeof(int) != read(fd,&nn0,2*sizeof(int))) laerror("cannot read");
resize(nn0[0]);
}
else
copyonwrite();
LA_traits<T>::multiget(nn,fd,v,dim);
}
2004-03-17 04:07:21 +01:00
// formatted print for NRVec
// Delegates entirely to lawritemat. It passes the element buffer v, the
// length nn, the caller's printf-style format and modulo. The remaining
// literal arguments (1, 1, 0) are forwarded unchanged.
// NOTE(review): presumably this prints the vector as a 1 x nn matrix, with
// `modulo` controlling line wrapping -- confirm against lawritemat.
template<typename T>
void NRVec<T>::fprintf(FILE *file, const char *format, const int modulo) const
{
lawritemat(file, v, 1, nn, format, 1, modulo, 0);
}
// formatted scan for NRVec
2005-02-14 01:10:07 +01:00
template <typename T>
2004-03-17 04:07:21 +01:00
void NRVec<T>::fscanf(FILE *f, const char *format)
{
int n;
if(std::fscanf(f, "%d", &n) != 1) laerror("cannot read vector dimension");
resize(n);
for (int i=0; i<n; i++)
if (std::fscanf(f, format, v+i) != 1)
laerror("cannot read the vector eleemnt");
}
// unary minus (generic element types): returns a freshly allocated vector
// holding the negation of every element; *this is not modified
template <typename T>
const NRVec<T> NRVec<T>::operator-() const
{
	NRVec<T> negated(nn);
	const T *src = v;
	T *dst = negated.v;
	for (int k = nn; k > 0; --k) *dst++ = -*src++;
	return negated;
}
2005-09-06 17:55:07 +02:00
//comparison operators (for lexical order)
// operator>: the first position where the elements differ (per
// LA_traits<T>::bigger/smaller) decides. If the shorter vector is a prefix of
// the longer one, the longer one compares greater.
template <typename T>
const bool NRVec<T>::operator>(const NRVec &rhs) const
{
	const int common = (rhs.nn < nn) ? rhs.nn : nn;
	int i = 0;
	while (i < common)
	{
		if (LA_traits<T>::bigger(v[i], rhs.v[i])) return true;
		if (LA_traits<T>::smaller(v[i], rhs.v[i])) return false;
		++i;
	}
	return nn > rhs.nn;
}
// operator<: mirror image of operator>. The first differing element decides;
// on a common prefix, the shorter vector compares smaller.
template <typename T>
const bool NRVec<T>::operator<(const NRVec &rhs) const
{
	const int common = (rhs.nn < nn) ? rhs.nn : nn;
	int i = 0;
	while (i < common)
	{
		if (LA_traits<T>::smaller(v[i], rhs.v[i])) return true;
		if (LA_traits<T>::bigger(v[i], rhs.v[i])) return false;
		++i;
	}
	return nn < rhs.nn;
}
2004-03-17 04:07:21 +01:00
// axpy call for T = double (not strided)
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRVec<double>::axpy(const double alpha, const NRVec<double> &x)
{
#ifdef DEBUG
if (nn != x.nn) laerror("axpy of incompatible vectors");
#endif
copyonwrite();
cblas_daxpy(nn, alpha, x.v, 1, v, 1);
}
// axpy call for T = complex<double> (not strided)
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRVec< complex<double> >::axpy(const complex<double> alpha,
const NRVec< complex<double> > &x)
{
#ifdef DEBUG
if (nn != x.nn) laerror("axpy of incompatible vectors");
#endif
copyonwrite();
cblas_zaxpy(nn, (void *)(&alpha), (void *)(x.v), 1, (void *)v, 1);
}
// axpy call for T = double (strided)
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRVec<double>::axpy(const double alpha, const double *x, const int stride)
{
copyonwrite();
cblas_daxpy(nn, alpha, x, stride, v, 1);
}
// axpy call for T = complex<double> (strided)
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRVec< complex<double> >::axpy(const complex<double> alpha,
const complex<double> *x, const int stride)
{
copyonwrite();
cblas_zaxpy(nn, (void *)(&alpha), (void *)x, stride, v, 1);
}
// unary minus
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRVec<double> NRVec<double>::operator-() const
{
NRVec<double> result(*this);
result.copyonwrite();
cblas_dscal(nn, -1.0, result.v, 1);
return result;
}
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRVec< complex<double> >
NRVec< complex<double> >::operator-() const
{
NRVec< complex<double> > result(*this);
result.copyonwrite();
cblas_zdscal(nn, -1.0, (void *)(result.v), 1);
return result;
}
// assignment of scalar to every element
// The storage is first detached with copyonwrite(), then every element is set
// to a. A plain loop is used for every value, including zero. A memset
// shortcut for a == 0 would store +0 when a is a negative zero (-0.0 compares
// equal to 0), and it would rely on the all-zero byte pattern meaning T(0).
template <typename T>
NRVec<T> & NRVec<T>::operator=(const T &a)
{
	copyonwrite();
	for (int i = 0; i < nn; i++) v[i] = a;
	return *this;
}
// Normalization of NRVec<double>
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRVec<double> & NRVec<double>::normalize()
{
double tmp;
tmp = cblas_dnrm2(nn, v, 1);
#ifdef DEBUG
if(!tmp) laerror("normalization of zero vector");
#endif
copyonwrite();
tmp = 1.0/tmp;
cblas_dscal(nn, tmp, v, 1);
return *this;
}
// Normalization of NRVec< complex<double> >
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
NRVec< complex<double> > & NRVec< complex<double> >::normalize()
{
complex<double> tmp;
tmp = cblas_dznrm2(nn, (void *)v, 1);
#ifdef DEBUG
if(!(tmp.real()) && !(tmp.imag())) laerror("normalization of zero vector");
#endif
copyonwrite();
tmp = 1.0/tmp;
cblas_zscal(nn, (void *)(&tmp), (void *)v, 1);
return *this;
}
2005-02-18 23:08:15 +01:00
//stubs for linkage
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 06:34:59 +01:00
NRVec<int> & NRVec<int>::normalize() {laerror("normalize() impossible for integer types"); return *this;}
2005-11-20 14:46:00 +01:00
template<>
2005-09-05 00:12:09 +02:00
NRVec<short> & NRVec<short>::normalize() {laerror("normalize() impossible for integer types"); return *this;}
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 06:34:59 +01:00
NRVec<char> & NRVec<char>::normalize() {laerror("normalize() impossible for integer types"); return *this;}
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<int>::gemv(const int beta,
const NRSMat<int> &A, const char trans,
const int alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-09-05 00:12:09 +02:00
void NRVec<short>::gemv(const short beta,
const NRSMat<short> &A, const char trans,
const short alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<char>::gemv(const char beta,
const NRSMat<char> &A, const char trans,
const char alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<int>::gemv(const int beta,
const NRMat<int> &A, const char trans,
const int alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-09-05 00:12:09 +02:00
void NRVec<short>::gemv(const short beta,
const NRMat<short> &A, const char trans,
const short alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<char>::gemv(const char beta,
const NRMat<char> &A, const char trans,
const char alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<int>::gemv(const int beta,
const SparseMat<int> &A, const char trans,
const int alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-09-05 00:12:09 +02:00
void NRVec<short>::gemv(const short beta,
const SparseMat<short> &A, const char trans,
const short alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<char>::gemv(const char beta,
const SparseMat<char> &A, const char trans,
const char alpha, const NRVec &x)
{
laerror("not yet implemented");
}
2004-03-17 06:34:59 +01:00
2005-02-18 23:08:15 +01:00
// gemv calls
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRVec<double>::gemv(const double beta, const NRMat<double> &A,
const char trans, const double alpha, const NRVec &x)
{
#ifdef DEBUG
if ((trans == 'n'?A.ncols():A.nrows()) != x.size())
laerror("incompatible sizes in gemv A*x");
#endif
cblas_dgemv(CblasRowMajor, (trans=='n' ? CblasNoTrans:CblasTrans),
2005-02-18 23:08:15 +01:00
A.nrows(), A.ncols(), alpha, A, A.ncols(), x.v, 1, beta, v, 1);
2004-03-17 04:07:21 +01:00
}
2005-02-18 23:08:15 +01:00
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
void NRVec< complex<double> >::gemv(const complex<double> beta,
const NRMat< complex<double> > &A, const char trans,
const complex<double> alpha, const NRVec &x)
{
#ifdef DEBUG
if ((trans == 'n'?A.ncols():A.nrows()) != x.size())
laerror("incompatible sizes in gemv A*x");
#endif
cblas_zgemv(CblasRowMajor, (trans=='n' ? CblasNoTrans:CblasTrans),
2005-02-18 23:08:15 +01:00
A.nrows(), A.ncols(), &alpha, A, A.ncols(),
x.v, 1, &beta, v, 1);
2004-03-17 04:07:21 +01:00
}
2005-02-18 23:08:15 +01:00
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec<double>::gemv(const double beta, const NRSMat<double> &A,
const char trans, const double alpha, const NRVec &x)
2004-03-17 04:07:21 +01:00
{
#ifdef DEBUG
2005-02-18 23:08:15 +01:00
if (A.ncols()!=x.size()) laerror("incompatible dimension in gemv A*x");
2004-03-17 04:07:21 +01:00
#endif
2005-02-18 23:08:15 +01:00
NRVec<double> result(nn);
cblas_dspmv(CblasRowMajor, CblasLower, A.ncols(), alpha, A, x.v, 1, beta, v, 1);
2004-03-17 04:07:21 +01:00
}
2005-02-18 23:08:15 +01:00
2005-11-20 14:46:00 +01:00
template<>
2005-02-18 23:08:15 +01:00
void NRVec< complex<double> >::gemv(const complex<double> beta,
const NRSMat< complex<double> > &A, const char trans,
const complex<double> alpha, const NRVec &x)
2004-03-17 04:07:21 +01:00
{
#ifdef DEBUG
2005-02-18 23:08:15 +01:00
if (A.ncols()!=x.size()) laerror("incompatible dimension in gemv");
2004-03-17 04:07:21 +01:00
#endif
2005-02-18 23:08:15 +01:00
NRVec< complex<double> > result(nn);
cblas_zhpmv(CblasRowMajor, CblasLower, A.ncols(), &alpha, A,
x.v, 1, &beta, v, 1);
2004-03-17 04:07:21 +01:00
}
2005-02-18 23:08:15 +01:00
2004-03-17 04:07:21 +01:00
// Direc product Mat = Vec | Vec
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat<double> NRVec<double>::operator|(const NRVec<double> &b) const
{
NRMat<double> result(0.,nn,b.nn);
cblas_dger(CblasRowMajor, nn, b.nn, 1., v, 1, b.v, 1, result, b.nn);
return result;
}
2005-11-20 14:46:00 +01:00
template<>
2004-03-17 04:07:21 +01:00
const NRMat< complex<double> >
NRVec< complex<double> >::operator|(const NRVec< complex<double> > &b) const
{
NRMat< complex<double> > result(0.,nn,b.nn);
cblas_zgerc(CblasRowMajor, nn, b.nn, &CONE, v, 1, b.v, 1, result, b.nn);
return result;
}