*** empty log message ***

This commit is contained in:
jiri
2010-01-11 10:12:28 +00:00
parent 12c88e6872
commit 8ec7c11a6e
6 changed files with 92 additions and 42 deletions

View File

@@ -30,11 +30,11 @@ namespace LA {
//dense times sparse (not necessarily symmetric)
template <typename T>
SparseSMat<T> & NRMat<T>::operator*(const SparseSMat<T> &rhs) const
SparseSMat<T> NRMat<T>::operator*(const SparseSMat<T> &rhs) const
{
SparseSMat<T> r(nn,rhs.ncols());
if(mm!=rhs.nrows()) laerror("incompatible sizes in NRMat*SparseSMat");
for(SPMatindex k=0; k<nn; ++k) //summation loop
for(SPMatindex k=0; k<mm; ++k) //summation loop
{
std::map<SPMatindex,T> * kl = rhs.line(k);
if(kl)
@@ -71,6 +71,7 @@ void SparseSMat<T>::gemm(const T beta, const SparseSMat &a, const char transa, c
{
(*this) *= beta;
if(alpha==(T)0) return;
if(a.nn!=a.mm || b.nn!=b.mm || nn!=mm) laerror("SparseSMat::gemm implemented only for square symmetric matrices");
if(a.nn!=b.nn || a.nn!=nn) laerror("incompatible sizes in SparseSMat::gemm");
copyonwrite();
@@ -89,13 +90,10 @@ for(SPMatindex k=0; k<nn; ++k) //summation loop
for(p=a.v[k]->begin(), i=0; p!=a.v[k]->end(); ++p,++i) { ai[i] = p->first; av[i] = LA_traits<T>::conjugate(p->second); }
else
for(p=a.v[k]->begin(), i=0; p!=a.v[k]->end(); ++p,++i) { ai[i] = p->first; av[i] = p->second; }
if(tolower(transb)=='c')
for(p=b.v[k]->begin(), i=0; p!=b.v[k]->end(); ++p,++i) { bi[i] = p->first; bv[i] = LA_traits<T>::conjugate(p->second); }
else
for(p=b.v[k]->begin(), i=0; p!=b.v[k]->end(); ++p,++i) { bi[i] = p->first; bv[i] = p->second; }
for(p=b.v[k]->begin(), i=0; p!=b.v[k]->end(); ++p,++i) { bi[i] = p->first; bv[i] = p->second; }
//make multiply via blas
NRMat<T> prod=av.otimes(bv,false,alpha);
NRMat<T> prod=av.otimes(bv,tolower(transb)=='c',alpha);
//scatter the results -- probably the computational bottleneck
for(i=0; i<prod.nrows(); ++i) for(j=0; j<prod.ncols(); ++j)
@@ -127,7 +125,7 @@ return *this;
template <class T>
void SparseSMat<T>::axpy(const T alpha, const SparseSMat &x, const bool transp)
{
if(nn!=x.nn) laerror("incompatible matrix dimensions in SparseSMat::axpy");
if(nn!=x.nn || mm!=x.mm) laerror("incompatible matrix dimensions in SparseSMat::axpy");
if(alpha==(T)0) return;
copyonwrite();
for(SPMatindex i=0; i<nn; ++i) if(x.v[i])
@@ -148,7 +146,8 @@ simplify();
template <class T>
void SparseSMat<T>::gemv(const T beta, NRVec<T> &r, const char trans, const T alpha, const NRVec<T> &x) const
{
if(nn!=r.size() || nn!= x.size()) laerror("incompatible matrix vector dimensions in SparseSMat::gemv");
if(nn!=r.size() || mm!= x.size()) laerror("incompatible matrix vector dimensions in SparseSMat::gemv");
if(trans) laerror("transposition not implemented yet in SparseSMat::gemv");
r *= beta;
if(alpha == (T)0) return;
r.copyonwrite();
@@ -236,6 +235,7 @@ return std::sqrt(sum);
template <class T>
const T* SparseSMat<T>::diagonalof(NRVec<T> &r, const bool divide, bool cache) const
{
if(nn!=mm) laerror("non-square matrix in SparseSMat::diagonalof");
if(nn!=r.size()) laerror("incompatible vector size in diagonalof()");
NRVec<T> *rr;
@@ -266,8 +266,7 @@ void SparseSMat<T>::get(int fd, bool dimen, bool transp) {
if(dimen) {
if(2*sizeof(SPMatindex)!=read(fd,&dim,2*sizeof(SPMatindex))) laerror("read() error in SparseSMat::get ");
if(dim[0]!=dim[1]) laerror("SparseSMat must be square (nonsquare read in ::get)");
resize(dim[0]);
resize(dim[0],dim[1]);
}
else copyonwrite();
@@ -288,7 +287,7 @@ void SparseSMat<T>::put(int fd, bool dimen, bool transp) const {
errno=0;
if(dimen) {
if(sizeof(SPMatindex)!=write(fd,&nn,sizeof(SPMatindex))) laerror("cannot write() 1 in SparseSMat::put");
if(sizeof(SPMatindex)!=write(fd,&nn,sizeof(SPMatindex))) laerror("cannot write() 2 in SparseSMat::put");
if(sizeof(SPMatindex)!=write(fd,&mm,sizeof(SPMatindex))) laerror("cannot write() 2 in SparseSMat::put");
}
typename SparseSMat<T>::iterator p(*this);