diff --git a/sparsesmat.h b/sparsesmat.h index 4c24028..8a8040f 100644 --- a/sparsesmat.h +++ b/sparsesmat.h @@ -30,6 +30,7 @@ #include "sparsemat.h" #include "vec.h" #include "mat.h" +#include "smat.h" #include #include @@ -52,11 +53,13 @@ public: SparseSMat(const SPMatindex n); SparseSMat(const SparseSMat &rhs); SparseSMat(const SparseMat &rhs); + SparseSMat(const NRSMat &rhs); SparseSMat & operator=(const SparseSMat &rhs); void copyonwrite(); - void clear(); + void resize(const SPMatindex n); + void clear() {resize(nn);} void simplify(); - ~SparseSMat() {clear();}; + ~SparseSMat(); // SparseSMat & operator=(const T a); //assign a to diagonal SparseSMat & operator+=(const T a); //assign a to diagonal @@ -65,11 +68,72 @@ public: SparseSMat & operator+=(const SparseSMat &rhs); SparseSMat & operator-=(const SparseSMat &rhs); void gemv(const T beta, NRVec &r, const char trans, const T alpha, const NRVec &x) const; - void gemm(const T beta, const SparseSMat &a, const char transa, const SparseSMat &b, const char transb, const T alpha); - const SparseSMat operator*(const SparseSMat &rhs) const; + const SparseSMat operator*(const SparseSMat &rhs) const; //!!!NOT A GENERAL ROUTINE, JUST FOR THE CASES WHEN THE RESULT STAYS SYMMETRIC const typename LA_traits::normtype norm(const T scalar=(T)0) const; - void add(const SPMatindex n, const SPMatindex m, const T elem); + void add(const SPMatindex n, const SPMatindex m, const T elem, const bool both=true); unsigned int length() const; + int nrows() const {return nn;} + int ncols() const {return nn;} + + class iterator {//not efficient, just for output to ostreams + private: + matel *p; + matel my; + SPMatindex row; + typename std::map::iterator *col; + typename std::map::iterator mycol; + SPMatindex mynn; + std::map **myv; + + + public: + //compiler-generated iterator & operator=(const iterator &rhs); + //compiler-generated iterator(const iterator &rhs); + iterator(): p(NULL),row(0),col(NULL),mynn(0),myv(NULL) {}; + iterator(const SparseSMat &rhs) : mynn(rhs.nn), myv(rhs.v), col(NULL) {row=0; p= &my; operator++();} + iterator operator++() { + if(col) //finish column list + { + ++mycol; + if(mycol != myv[row]->end()) + { + p->row = row; + p->col = mycol->first; + p->elem = mycol->second; + return *this; + } + else + { + col=NULL; + ++row; if(row==mynn) {p=NULL; return *this;} //end() + } + } + nextrow: + while(myv[row]==NULL) {++row; if(row==mynn) {p=NULL; return *this;}} //end() + + //we are at next nonempty row + col = &mycol; + mycol = myv[row]->begin(); + if(mycol == myv[row]->end()) {col=NULL; + ++row; + if(row==mynn) {p=NULL; return *this;} else goto nextrow; + } + //first column of new row + p->row = row; + p->col = mycol->first; + p->elem = mycol->second; + return *this; + }; + iterator(matel *q) {p=q; col=NULL;}//just for end() + //compiler-generated ~iterator() {}; + bool operator!=(const iterator &rhs) const {return p!=rhs.p;} //only for comparison with end() + bool operator==(const iterator &rhs) const {return p==rhs.p;} //only for comparison with end() + matel & operator*() const {return *p;} + matel * operator->() const {return p;} + bool notend() const {return (bool)p;}; + }; + iterator begin() const {return iterator(*this);} + iterator end() const {return iterator(NULL);} }; template @@ -92,13 +156,13 @@ if(count) (*count)++; template -void SparseSMat::clear() +SparseSMat::~SparseSMat() { if(!count) return; if(--(*count) <= 0) { if(v) { - for(SPMatindex i=0; i::clear() } +template +void SparseSMat::resize(const SPMatindex n) +{ +if(!count) + { + if(n==0) return; + count = new int(1); + nn=n; + v= new std::map * [nn]; + for(SPMatindex i=0; i 1) //it was shared + { + (*count)--; + if(n) + { + count = new int(1); + nn=n; + v= new std::map * [nn]; + for(SPMatindex i=0; i SparseSMat & SparseSMat::operator=(const SparseSMat &rhs) @@ -150,7 +251,7 @@ void SparseSMat::copyonwrite() template -void SparseSMat::add(const SPMatindex n, const SPMatindex m, const T elem) +void SparseSMat::add(const SPMatindex n, const SPMatindex m, const T elem, const bool both) { #ifdef DEBUG if(n>=nn || m>=nn) laerror("illegal index in SparseSMat::add()"); @@ -162,7 +263,7 @@ typename std::map::iterator p; p= v[n]->find(m); if(p!=v[n]->end()) p->second+=elem; else (*v[n])[m] = elem; -if(n!=m) +if(n!=m && both) //add also transposed { p= v[m]->find(n); if(p!=v[m]->end()) p->second+=elem; else (*v[m])[n] = elem; @@ -188,5 +289,46 @@ for(SPMatindex i=0; i +std::ostream & operator<<(std::ostream &s, const SparseSMat &x) +{ + SPMatindex n; + + n = x.nrows(); + s << n << " "<< n<< std::endl; + +typename SparseSMat::iterator p(x); +for(; +p.notend(); +++p) +s << (int)p->row << ' ' << (int)p->col << ' ' << (typename LA_traits_io::IOtype) p->elem << '\n'; + + s << "-1 -1\n"; + return s; +} + +template +std::istream& operator>>(std::istream &s, SparseSMat &x) + { + SPMatindex n,m; + long i,j; + s >> n >> m; + if(n!=m) laerror("SparseSMat must be square"); + x.resize(n); + s >> i >> j; + typename LA_traits_io::IOtype tmp; + while(i>=0 && j>=0) + { + s>>tmp; + x.add(i,j,tmp,false); + s >> i >> j; + } + return s; + } + + + + + #endif //_SPARSESMAT_H_