From 91d130e0f7b1e17fda70e293c5d6bc431bb439b6 Mon Sep 17 00:00:00 2001 From: jiri Date: Wed, 11 Nov 2009 22:07:25 +0000 Subject: [PATCH] *** empty log message *** --- sparsesmat.cc | 101 ++++++++++++++++++++++++++++++++++++++++++++++++++ sparsesmat.h | 61 +++++++++++++++--------------- 2 files changed, 132 insertions(+), 30 deletions(-) create mode 100644 sparsesmat.cc diff --git a/sparsesmat.cc b/sparsesmat.cc new file mode 100644 index 0000000..5932ef2 --- /dev/null +++ b/sparsesmat.cc @@ -0,0 +1,101 @@ +/* + LA: linear algebra C++ interface library + Copyright (C) 2008 Jiri Pittner or + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "sparsesmat.h" + +template +void SparseSMat::gemm(const T beta, const SparseSMat &a, const char transa, const SparseSMat &b, const char transb, const T alpha) +{ +(*this) *= beta; +if(alpha==(T)0) return; +if(a.nn!=b.nn || a.nn!=nn) laerror("incompatible sizes in SparseSMat::gemm"); +copyonwrite(); + +for(SPMatindex k=0; k av(a.v[k]->size()); + NRVec bv(b.v[k]->size()); + NRVec ai(a.v[k]->size()); + NRVec bi(b.v[k]->size()); + + //gather the data + typename std::map::iterator p; + int i,j; + for(p=a.v[k]->begin(), i=0; p!=a.v[k]->end(); ++p,++i) + { + ai[i] = p->first; + av[i] = p->second; + } + for(p=b.v[k]->begin(), i=0; p!=b.v[k]->end(); ++p,++i) + { + bi[i] = p->first; + bv[i] = p->second; + } + + //make multiply via blas + NRMat prod=av.otimes(bv,false,alpha); + + //scatter the results + for(i=0; i +SparseSMat & SparseSMat::operator*=(const T &a) +{ +if(!count) laerror("operator*= on undefined lhs"); +if(a==(T)1) return *this; +if(a==(T)0) {clear(); return *this;} +copyonwrite(); + +for(SPMatindex i=0; i::iterator p; + for(p=v[i]->begin(); p!=v[i]->end(); ++p) p->second *= a; + } + +return *this; +} + + + +#define INSTANTIZE(T) \ +template void SparseSMat::gemm(const T beta, const SparseSMat &a, const char transa, const SparseSMat &b, const char transb, const T alpha); \ +template SparseSMat & SparseSMat::operator*=(const T &a); \ + +INSTANTIZE(double) + +INSTANTIZE(complex) + +//// forced instantization of functions in the header in the corresponding object file +template class SparseSMat; +template class SparseSMat >; + + diff --git a/sparsesmat.h b/sparsesmat.h index 3282389..41ec3bd 100644 --- a/sparsesmat.h +++ b/sparsesmat.h @@ -62,24 +62,25 @@ public: void simplify(); ~SparseSMat(); inline int getcount() const {return count?*count:0;} - // + SparseSMat & operator*=(const T &a); //multiply by a scalar + inline const SparseSMat operator*(const T &rhs) const {return SparseSMat(*this) *= rhs;} +/*@@@to be done inline const SparseSMat operator+(const T &rhs) const {return SparseSMat(*this) += rhs;} inline const SparseSMat operator-(const T &rhs) const {return SparseSMat(*this) -= rhs;} - inline const SparseSMat operator*(const T &rhs) const {return SparseSMat(*this) *= rhs;} inline const SparseSMat operator+(const SparseSMat &rhs) const {return SparseSMat(*this) += rhs;} inline const SparseSMat operator-(const SparseSMat &rhs) const {return SparseSMat(*this) -= rhs;} - inline const SparseSMat operator*(const SparseSMat &rhs) const; //!!!NOT A GENERAL ROUTINE, JUST FOR THE CASES WHEN THE RESULT STAYS SYMMETRIC - SparseSMat & operator=(const T &a); //assign a to diagonal SparseSMat & operator+=(const T &a); //assign a to diagonal SparseSMat & operator-=(const T &a); //assign a to diagonal - SparseSMat & operator*=(const T &a); //multiply by a scalar SparseSMat & operator+=(const SparseSMat &rhs); SparseSMat & operator-=(const SparseSMat &rhs); void gemv(const T beta, NRVec &r, const char trans, const T alpha, const NRVec &x) const; void axpy(const T alpha, const SparseSMat &x, const bool transp=0); // this+= a*x const typename LA_traits::normtype norm(const T scalar=(T)0) const; - void add(const SPMatindex n, const SPMatindex m, const T elem, const bool both=true); +*/ + inline const SparseSMat operator*(const SparseSMat &rhs) const {SparseSMat r(nn); r.gemm(0,*this,'n',rhs,'n',1); return r;}; //!!!NOT A GENERAL ROUTINE, JUST FOR THE CASES WHEN THE RESULT STAYS SYMMETRIC + void gemm(const T beta, const SparseSMat &a, const char transa, const SparseSMat &b, const char transb, const T alpha); //this := alpha*op( A )*op( B ) + beta*this !!!NOT A GENERAL ROUTINE, JUST FOR THE CASES WHEN THE RESULT STAYS SYMMETRIC + inline void add(const SPMatindex n, const SPMatindex m, const T elem, const bool both=true); unsigned int length() const; void transposeme() const {}; int nrows() const {return nn;} @@ -293,7 +294,6 @@ void SparseSMat::add(const SPMatindex n, const SPMatindex m, const T elem, co if(n>=nn || m>=nn) laerror("illegal index in SparseSMat::add()"); #endif if(!v[n]) v[n] = new std::map; -if(!v[m]) v[m] = new std::map; typename std::map::iterator p; @@ -301,6 +301,7 @@ p= v[n]->find(m); if(p!=v[n]->end()) p->second+=elem; else (*v[n])[m] = elem; if(n!=m && both) //add also transposed { + if(!v[m]) v[m] = new std::map; p= v[m]->find(n); if(p!=v[m]->end()) p->second+=elem; else (*v[m])[n] = elem; } @@ -317,42 +318,42 @@ for(SPMatindex i=0; i l; typename std::map::iterator p; - for(p=v[i]->begin(); p!=v[i]->end(); ++p) - if(std::abs(p->second) < SPARSEEPSILON) l.push_front(p->first); - typename std::list::iterator q; - for(q=l.begin(); q!=l.end(); ++q) v[i]->erase(*q); - if(v[i]->size() == 0) delete v[i]; - } +for(p=v[i]->begin(); p!=v[i]->end(); ++p) + if(std::abs(p->second) < SPARSEEPSILON) l.push_front(p->first); +typename std::list::iterator q; +for(q=l.begin(); q!=l.end(); ++q) v[i]->erase(*q); +if(v[i]->size() == 0) delete v[i]; +} } template std::ostream & operator<<(std::ostream &s, const SparseSMat &x) { - SPMatindex n; +SPMatindex n; - n = x.nrows(); - s << n << " "<< n<< std::endl; +n = x.nrows(); +s << n << " "<< n<< std::endl; typename SparseSMat::iterator p(x); for(; p.notend(); ++p) s << (int)p->row << ' ' << (int)p->col << ' ' << (typename LA_traits_io::IOtype) p->elem << '\n'; - s << "-1 -1\n"; - return s; +s << "-1 -1\n"; +return s; } template std::istream& operator>>(std::istream &s, SparseSMat &x) - { - SPMatindex n,m; - long i,j; - s >> n >> m; - if(n!=m) laerror("SparseSMat must be square"); - x.resize(n); - s >> i >> j; - typename LA_traits_io::IOtype tmp; - while(i>=0 && j>=0) - { - s>>tmp; - x.add(i,j,tmp,false); + { + SPMatindex n,m; + long i,j; + s >> n >> m; + if(n!=m) laerror("SparseSMat must be square"); + x.resize(n); + s >> i >> j; + typename LA_traits_io::IOtype tmp; + while(i>=0 && j>=0) + { + s>>tmp; + x.add(i,j,tmp,false); s >> i >> j; } return s;