diff --git a/fourindex.h b/fourindex.h index 1fc58f8..2a1f65c 100644 --- a/fourindex.h +++ b/fourindex.h @@ -103,15 +103,15 @@ public: piterator(matel4 *pp): symmetry(nosymmetry),p(pp),permindex(0){}; ~piterator() {}; piterator(const fourindex &x): symmetry(x.symmetry),p(x.list),permindex(0) {setup();}; - piterator& operator++() {if(++permindex==fourindex_permnumbers[symmetry]) {permindex=0; p=p->next;} setup(); return *this;} + piterator& operator++() {if(++permindex>=fourindex_permnumbers[symmetry]) {permindex=0; p=p->next;} setup(); return *this;} const matel4 & operator*() const {return my;} const matel4 * operator->() const {return &my;} piterator operator++(int) {laerror("postincrement not possible on permute-iterator");} - bool operator==(const piterator &rhs) const {return p==rhs.p && permindex==rhs.permindex && symmetry==rhs.symmetry;} - bool operator!=(const piterator &rhs) const {return p!=rhs.p || permindex!=rhs.permindex || symmetry!=rhs.symmetry;} + bool operator==(const piterator &rhs) const {return p==rhs.p && (!p || permindex==rhs.permindex);} + bool operator!=(const piterator &rhs) const {return p!=rhs.p || p && rhs.p && permindex!=rhs.permindex;} }; - piterator pbegin() const {return *this;} - piterator pend() const {return NULL;} + piterator pbegin() const {return piterator(*this);} + piterator pend() const {return piterator(NULL);} //constructors etc. inline fourindex() :nn(0),count(NULL),list(NULL) {}; diff --git a/la_traits.h b/la_traits.h index 0c221c5..77bb29b 100644 --- a/la_traits.h +++ b/la_traits.h @@ -39,14 +39,12 @@ typedef class scalar_true {}; //default is non-scalar template -class isscalar { - typedef scalar_false scalar_type; - }; +class isscalar { public: typedef scalar_false scalar_type;}; //specializations #define SCALAR(X) \ template<>\ -class isscalar {typedef scalar_true scalar_type;}; +class isscalar {public: typedef scalar_true scalar_type;}; //declare what is scalar SCALAR(char) diff --git a/mat.cc b/mat.cc index fa6d3e8..e6469a4 100644 --- a/mat.cc +++ b/mat.cc @@ -3,6 +3,7 @@ #include #include #include +#include extern "C" { extern ssize_t read(int, void *, size_t); extern ssize_t write(int, const void *, size_t); @@ -12,11 +13,11 @@ extern ssize_t write(int, const void *, size_t); ////////////////////////////////////////////////////////////////////////////// //// forced instantization in the corresponding object file -template NRMat; -template NRMat< complex >; -template NRMat; -template NRMat; -template NRMat; +template class NRMat; +template class NRMat< complex >; +template class NRMat; +template class NRMat; +template class NRMat; /* @@ -288,12 +289,16 @@ void NRMat::fscanf(FILE *f, const char *format) */ // Mat *= a +template<> NRMat & NRMat::operator*=(const double &a) { copyonwrite(); cblas_dscal(nn*mm, a, *this, 1); return *this; } + + +template<> NRMat< complex > & NRMat< complex >::operator*=(const complex &a) { @@ -302,6 +307,7 @@ NRMat< complex >::operator*=(const complex &a) return *this; } + //and for general type template NRMat & NRMat::operator*=(const T &a) @@ -318,6 +324,7 @@ NRMat & NRMat::operator*=(const T &a) // Mat += Mat +template<> NRMat & NRMat::operator+=(const NRMat &rhs) { #ifdef DEBUG @@ -328,6 +335,9 @@ NRMat & NRMat::operator+=(const NRMat &rhs) cblas_daxpy(nn*mm, 1.0, rhs, 1, *this, 1); return *this; } + + +template<> NRMat< complex > & NRMat< complex >::operator+=(const NRMat< complex > &rhs) { @@ -340,6 +350,8 @@ NRMat< complex >::operator+=(const NRMat< complex > &rhs) return *this; } + + //and for general type template NRMat & NRMat::operator+=(const NRMat &rhs) @@ -359,6 +371,7 @@ NRMat & NRMat::operator+=(const NRMat &rhs) // Mat -= Mat +template<> NRMat & NRMat::operator-=(const NRMat &rhs) { #ifdef DEBUG @@ -369,6 +382,10 @@ NRMat & NRMat::operator-=(const NRMat &rhs) cblas_daxpy(nn*mm, -1.0, rhs, 1, *this, 1); return *this; } + + + +template<> NRMat< complex > & NRMat< complex >::operator-=(const NRMat< complex > &rhs) { @@ -381,6 +398,8 @@ NRMat< complex >::operator-=(const NRMat< complex > &rhs) return *this; } + + //and for general type template NRMat & NRMat::operator-=(const NRMat &rhs) @@ -400,6 +419,7 @@ NRMat & NRMat::operator-=(const NRMat &rhs) // Mat += SMat +template<> NRMat & NRMat::operator+=(const NRSMat &rhs) { #ifdef DEBUG @@ -418,6 +438,10 @@ NRMat & NRMat::operator+=(const NRSMat &rhs) } return *this; } + + + +template<> NRMat< complex > & NRMat< complex >::operator+=(const NRSMat< complex > &rhs) { @@ -438,6 +462,9 @@ NRMat< complex >::operator+=(const NRSMat< complex > &rhs) return *this; } + + + //and for general type template NRMat & NRMat::operator+=(const NRSMat &rhs) @@ -461,6 +488,7 @@ NRMat & NRMat::operator+=(const NRSMat &rhs) // Mat -= SMat +template<> NRMat & NRMat::operator-=(const NRSMat &rhs) { #ifdef DEBUG @@ -479,6 +507,10 @@ NRMat & NRMat::operator-=(const NRSMat &rhs) } return *this; } + + + +template<> NRMat< complex > & NRMat< complex >::operator-=(const NRSMat< complex > &rhs) { @@ -521,7 +553,11 @@ NRMat & NRMat::operator-=(const NRSMat &rhs) return *this; } + + + // Mat.Mat - scalar product +template<> const double NRMat::dot(const NRMat &rhs) const { #ifdef DEBUG @@ -529,6 +565,10 @@ const double NRMat::dot(const NRMat &rhs) const #endif return cblas_ddot(nn*mm, (*this)[0], 1, rhs[0], 1); } + + + +template<> const complex NRMat< complex >::dot(const NRMat< complex > &rhs) const { @@ -542,6 +582,7 @@ NRMat< complex >::dot(const NRMat< complex > &rhs) const } // Mat * Mat +template<> const NRMat NRMat::operator*(const NRMat &rhs) const { #ifdef DEBUG @@ -552,6 +593,10 @@ const NRMat NRMat::operator*(const NRMat &rhs) const *this, mm, rhs, rhs.mm, 0.0, result, rhs.mm); return result; } + + + +template<> const NRMat< complex > NRMat< complex >::operator*(const NRMat< complex > &rhs) const { @@ -565,7 +610,9 @@ NRMat< complex >::operator*(const NRMat< complex > &rhs) const return result; } + // Multiply by diagonal from L +template<> void NRMat::diagmultl(const NRVec &rhs) { #ifdef DEBUG @@ -574,6 +621,10 @@ void NRMat::diagmultl(const NRVec &rhs) copyonwrite(); for(int i=0; i void NRMat< complex >::diagmultl(const NRVec< complex > &rhs) { #ifdef DEBUG @@ -583,7 +634,11 @@ void NRMat< complex >::diagmultl(const NRVec< complex > &rhs) for (int i=0; i void NRMat::diagmultr(const NRVec &rhs) { #ifdef DEBUG @@ -592,6 +647,10 @@ void NRMat::diagmultr(const NRVec &rhs) copyonwrite(); for (int i=0; i void NRMat< complex >::diagmultr(const NRVec< complex > &rhs) { #ifdef DEBUG @@ -601,7 +660,11 @@ void NRMat< complex >::diagmultr(const NRVec< complex > &rhs) for (int i=0; i const NRMat NRMat::operator*(const NRSMat &rhs) const { @@ -614,6 +677,10 @@ NRMat::operator*(const NRSMat &rhs) const (*this)[i], 1, 0.0, result[i], 1); return result; } + + + +template<> const NRMat< complex > NRMat< complex >::operator*(const NRSMat< complex > &rhs) const { @@ -629,6 +696,7 @@ NRMat< complex >::operator*(const NRSMat< complex > &rhs) const // sum of rows +template<> const NRVec NRMat::rsum() const { NRVec result(mm); @@ -637,6 +705,7 @@ const NRVec NRMat::rsum() const } // sum of columns +template<> const NRVec NRMat::csum() const { NRVec result(nn); @@ -645,8 +714,10 @@ const NRVec NRMat::csum() const } // complex conjugate of Mat +template<> NRMat &NRMat::conjugateme() {return *this;} +template<> NRMat< complex > & NRMat< complex >::conjugateme() { copyonwrite(); @@ -655,12 +726,14 @@ NRMat< complex > & NRMat< complex >::conjugateme() } // transpose and optionally conjugate +template<> const NRMat NRMat::transpose(bool conj) const { NRMat result(mm,nn); for(int i=0; i const NRMat< complex > NRMat< complex >::transpose(bool conj) const { @@ -672,6 +745,7 @@ NRMat< complex >::transpose(bool conj) const } // gemm : this = alpha*op( A )*op( B ) + beta*this +template<> void NRMat::gemm(const double &beta, const NRMat &a, const char transa, const NRMat &b, const char transb, const double &alpha) @@ -691,6 +765,10 @@ void NRMat::gemm(const double &beta, const NRMat &a, (transb=='n' ? CblasNoTrans : CblasTrans), nn, mm, k, alpha, a, a.mm, b , b.mm, beta, *this , mm); } + + + +template<> void NRMat< complex >::gemm(const complex & beta, const NRMat< complex > & a, const char transa, const NRMat< complex > & b, const char transb, @@ -714,6 +792,7 @@ void NRMat< complex >::gemm(const complex & beta, } // norm of Mat +template<> const double NRMat::norm(const double scalar) const { if (!scalar) return cblas_dnrm2(nn*mm, (*this)[0], 1); @@ -731,6 +810,10 @@ const double NRMat::norm(const double scalar) const } return sqrt(sum); } + + + +template<> const double NRMat< complex >::norm(const complex scalar) const { if (scalar == CZERO) return cblas_dznrm2(nn*mm, (*this)[0], 1); @@ -749,7 +832,11 @@ const double NRMat< complex >::norm(const complex scalar) const return sqrt(sum); } + + + // axpy: this = a * Mat +template<> void NRMat::axpy(const double alpha, const NRMat &mat) { #ifdef DEBUG @@ -758,6 +845,10 @@ void NRMat::axpy(const double alpha, const NRMat &mat) copyonwrite(); cblas_daxpy(nn*mm, alpha, mat, 1, *this, 1); } + + + +template<> void NRMat< complex >::axpy(const complex alpha, const NRMat< complex > & mat) { @@ -768,14 +859,24 @@ void NRMat< complex >::axpy(const complex alpha, cblas_zaxpy(nn*mm, (void *)&alpha, mat, 1, (void *)(*this)[0], 1); } + + + // trace of Mat +template<> const double NRMat::trace() const { #ifdef DEBUG if (nn != mm) laerror("no-square matrix in Mat::trace()"); #endif return cblas_dasum(nn, (*this)[0], nn+1); + + } + + + +template<> const complex NRMat< complex >::trace() const { #ifdef DEBUG @@ -795,6 +896,7 @@ const complex NRMat< complex >::trace() const //get diagonal; for compatibility with large matrices do not return newly created object //for non-square get diagonal of A^TA, will be used as preconditioner +template<> void NRMat::diagonalof(NRVec &r, const bool divide) const { #ifdef DEBUG diff --git a/smat.cc b/smat.cc index fe13588..ecfc7f7 100644 --- a/smat.cc +++ b/smat.cc @@ -3,6 +3,7 @@ #include #include #include +#include extern "C" { extern ssize_t read(int, void *, size_t); extern ssize_t write(int, const void *, size_t); @@ -13,11 +14,11 @@ extern ssize_t write(int, const void *, size_t); ////////////////////////////////////////////////////////////////////////////// ////// forced instantization in the corresponding object file -template NRSMat; -template NRSMat< complex >; -template NRSMat; -template NRSMat; -template NRSMat; +template class NRSMat; +template class NRSMat< complex >; +template class NRSMat; +template class NRSMat; +template class NRSMat; @@ -146,6 +147,7 @@ void NRSMat::fscanf(FILE *f, const char *format) */ // SMat * Mat +template<> const NRMat NRSMat::operator*(const NRMat &rhs) const { #ifdef DEBUG @@ -157,6 +159,9 @@ const NRMat NRSMat::operator*(const NRMat &rhs) const 0.0, result[0]+k, rhs.ncols()); return result; } + + +template<> const NRMat< complex > NRSMat< complex >::operator*(const NRMat< complex > &rhs) const { @@ -170,7 +175,9 @@ NRSMat< complex >::operator*(const NRMat< complex > &rhs) const return result; } + // SMat * SMat +template<> const NRMat NRSMat::operator*(const NRSMat &rhs) const { #ifdef DEBUG @@ -218,6 +225,10 @@ const NRMat NRSMat::operator*(const NRSMat &rhs) const return result; } + + + +template<> const NRMat< complex > NRSMat< complex >::operator*(const NRSMat< complex > &rhs) const { @@ -230,7 +241,12 @@ NRSMat< complex >::operator*(const NRSMat< complex > &rhs) const return result; // laerror("complex SMat*Smat not implemented"); } + + + + // S dot S +template<> const double NRSMat::dot(const NRSMat &rhs) const { #ifdef DEBUG @@ -238,6 +254,10 @@ const double NRSMat::dot(const NRSMat &rhs) const #endif return cblas_ddot(NN2, v, 1, rhs.v, 1); } + + + +template<> const complex NRSMat< complex >::dot(const NRSMat< complex > &rhs) const { @@ -251,6 +271,7 @@ NRSMat< complex >::dot(const NRSMat< complex > &rhs) const // norm of the matrix +template<> const double NRSMat::norm(const double scalar) const { if (!scalar) return cblas_dnrm2(NN2, v, 1); @@ -265,8 +286,11 @@ const double NRSMat::norm(const double scalar) const } return sqrt(sum); } -const double -NRSMat< complex >::norm(const complex scalar) const + + + +template<> +const double NRSMat< complex >::norm(const complex scalar) const { if (!(scalar.real()) && !(scalar.imag())) return cblas_dznrm2(NN2, (void *)v, 1); @@ -282,7 +306,12 @@ NRSMat< complex >::norm(const complex scalar) const return sqrt(sum); } + + + + // axpy: S = S * a +template<> void NRSMat::axpy(const double alpha, const NRSMat & x) { #ifdef DEBUG @@ -291,6 +320,10 @@ void NRSMat::axpy(const double alpha, const NRSMat & x) copyonwrite(); cblas_daxpy(NN2, alpha, x.v, 1, v, 1); } + + + +template<> void NRSMat< complex >::axpy(const complex alpha, const NRSMat< complex > & x) { diff --git a/sparsemat.cc b/sparsemat.cc index 6ed5d6f..8f7c18f 100644 --- a/sparsemat.cc +++ b/sparsemat.cc @@ -4,12 +4,13 @@ #include #include #include +#include #include "sparsemat.h" ////////////////////////////////////////////////////////////////////////////// //// forced instantization in the corresponding object file -template SparseMat; -template SparseMat >; +template class SparseMat; +template class SparseMat >; @@ -790,6 +791,7 @@ while(l) return *this; } +template<> const double SparseMat::dot(const NRMat &rhs) const { double r=0; @@ -803,12 +805,14 @@ while(l) return r; } +template<> void NRMat >::gemm(const complex &beta, const SparseMat > &a, const char transa, const NRMat > &b, const char transb, const complex &alpha) { laerror("not implemented yet"); } +template<> void NRMat::gemm(const double &beta, const SparseMat &a, const char transa, const NRMat &b, const char transb, const double &alpha) { bool transpa = tolower(transa)!='n'; //not OK for complex diff --git a/vec.cc b/vec.cc index c115c04..80c824c 100644 --- a/vec.cc +++ b/vec.cc @@ -29,11 +29,11 @@ INSTANTIZE(short) INSTANTIZE(unsigned short) INSTANTIZE(char) INSTANTIZE(unsigned char) -template NRVec; -template NRVec >; -template NRVec; -template NRVec; -template NRVec; +template class NRVec; +template class NRVec >; +template class NRVec; +template class NRVec; +template class NRVec;