continueing on permutations

This commit is contained in:
2021-05-19 22:29:47 +02:00
parent 83b9463334
commit 78c94f1e17
9 changed files with 558 additions and 31 deletions

245
mat.cc
View File

@@ -27,6 +27,7 @@
#include <fcntl.h>
#include <errno.h>
#include <unistd.h>
#include <math.h>
namespace LA {
@@ -2754,6 +2755,7 @@ NRMat<T>& NRMat<T>::swap_rows(){
return *this;
}
/***************************************************************************//**
* interchange the order of the columns of the current (real) matrix
* @return reference to the modified matrix
@@ -2925,6 +2927,61 @@ NRMat<T>& NRMat<T>::swap_rows(const int a, const int b){
return *this;
}
/*rotate rows or columns of a matrix - general implementation, more efficient version could be done with BLAS scal and axpy operations
* but it would require allocation of temporary storage
*/
template<typename T>
NRMat<T>& NRMat<T>::rotate_rows(const int a, const int b, const T phi){
T tmp1,tmp2;
copyonwrite();
T c=cos(phi);
T s=sin(phi);
#ifdef CUDALA
if(location == cpu){
#endif
for(register int j=0;j<mm;j++){
tmp1 = (*this)(a,j);
tmp2 = (*this)(b,j);
(*this)(a,j) = c*tmp1 + s*tmp2;
(*this)(b,j) = c*tmp2 - s*tmp1;
}
#ifdef CUDALA
}else{
laerror("rotate_rows not implemented on gpu");
}
#endif
return *this;
}
template<typename T>
NRMat<T>& NRMat<T>::rotate_cols(const int a, const int b, const T phi){
T tmp1,tmp2;
copyonwrite();
T c=cos(phi);
T s=sin(phi);
#ifdef CUDALA
if(location == cpu){
#endif
for(register int j=0;j<nn;j++){
tmp1 = (*this)(j,a);
tmp2 = (*this)(j,b);
(*this)(j,a) = c*tmp1 + s*tmp2;
(*this)(j,b) = c*tmp2 - s*tmp1;
}
#ifdef CUDALA
}else{
laerror("rotate_rows not implemented on gpu");
}
#endif
return *this;
}
/***************************************************************************//**
* interchange the order of the rows and columns of the current
* real matrix \f$A\f$ of type T, i.e. perform the operation
@@ -3075,9 +3132,23 @@ NRMat<T>& NRMat<T>::swap_rows_cols(){
return *this;
}
//permutation matrix
template<typename T>
NRMat<T>::NRMat(const NRPerm<int> &p, const bool direction)
{
int n=p.size();
resize(n,n);
clear();
for(int i=0; i<n; ++i)
{
if(direction) (*this)(i,p[i+1]-1)=1;
else (*this)(p[i+1]-1,i)=1;
}
}
//apply permutations
template<typename T>
const NRMat<T> NRMat<T>::permute_rows(const NRPerm<int> &p) const
const NRMat<T> NRMat<T>::permuted_rows(const NRPerm<int> &p, const bool inverse) const
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of matrix");
@@ -3088,12 +3159,13 @@ if(n!=nn) laerror("incompatible permutation and matrix");
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRMat<T> r(nn,mm);
for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<mm; ++j) r(i-1,j) = (*this)(pi,j);}
if(inverse) for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<mm; ++j) r(i-1,j) = (*this)(pi,j);}
else for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<mm; ++j) r(pi,j) = (*this)(i-1,j);}
return r;
}
template<typename T>
const NRMat<T> NRMat<T>::permute_cols(const NRPerm<int> &p) const
const NRMat<T> NRMat<T>::permuted_cols(const NRPerm<int> &p, const bool inverse) const
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of matrix");
@@ -3104,12 +3176,13 @@ if(n!=mm) laerror("incompatible permutation and matrix");
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRMat<T> r(nn,mm);
for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<nn; ++j) r(j,i-1) = (*this)(j,pi);}
if(inverse) for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<nn; ++j) r(j,i-1) = (*this)(j,pi);}
else for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<nn; ++j) r(j,pi) = (*this)(j,i-1);}
return r;
}
template<typename T>
const NRMat<T> NRMat<T>::permute_both(const NRPerm<int> &p, const NRPerm<int> &q) const
const NRMat<T> NRMat<T>::permuted_both(const NRPerm<int> &p, const NRPerm<int> &q, const bool inverse) const
{
#ifdef DEBUG
if(!p.is_valid() || !q.is_valid() ) laerror("invalid permutation of matrix");
@@ -3121,11 +3194,171 @@ if(n!=nn ||m!=mm) laerror("incompatible permutation and matrix");
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRMat<T> r(nn,mm);
for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=1; j<=m; ++j) r(i-1,j-1) = (*this)(pi,q[j]-1);}
if(inverse) for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=1; j<=m; ++j) r(i-1,j-1) = (*this)(pi,q[j]-1);}
else for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=1; j<=m; ++j) r(pi,q[j]-1) = (*this)(i-1,j-1);}
return r;
}
template<typename T>
void NRMat<T>::permuteme_rows(const CyclePerm<int> &p)
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of matrix");
#endif
if(p.max()>nn) laerror("incompatible permutation and matrix");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
copyonwrite();
T *tmp = new T[mm];
for(int cycle=1; cycle<=p.size(); ++cycle)
{
int length= p[cycle].size();
if(length<=1) continue; //trivial cycle
for(int j=0; j<mm; ++j) tmp[j] = (*this)(p[cycle][length]-1,j);
for(int i=length; i>1; --i)
for(int j=0; j<mm; ++j) (*this)(p[cycle][i]-1,j)=(*this)(p[cycle][i-1]-1,j);
for(int j=0; j<mm; ++j) (*this)(p[cycle][1]-1,j)=tmp[j];
}
delete[] tmp;
}
template<typename T>
void NRMat<T>::permuteme_cols(const CyclePerm<int> &p)
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of matrix");
#endif
if(p.max()>mm) laerror("incompatible permutation and matrix");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
copyonwrite();
T *tmp = new T[nn];
for(int cycle=1; cycle<=p.size(); ++cycle)
{
int length= p[cycle].size();
if(length<=1) continue; //trivial cycle
for(int j=0; j<nn; ++j) tmp[j] = (*this)(j,p[cycle][length]-1);
for(int i=length; i>1; --i)
for(int j=0; j<nn; ++j) (*this)(j,p[cycle][i]-1)=(*this)(j,p[cycle][i-1]-1);
for(int j=0; j<nn; ++j) (*this)(j,p[cycle][1]-1)=tmp[j];
}
delete[] tmp;
}
//double and complex specialization
template<>
void NRMat<double>::scale_row(const int i, const double f)
{
#ifdef DEBUG
if(i<0||i>=nn) laerror("index out of range in scale_row");
#endif
copyonwrite();
#ifdef CUDALA
if(location == cpu) {
#endif
cblas_dscal(mm, f, &(*this)(i,0), 1);
#ifdef CUDALA
}else{
cublasDscal(mm, f, v+i*mm, 1);
TEST_CUBLAS("cublasDscal");
}
#endif
}
template<>
void NRMat<double>::scale_col(const int i, const double f)
{
#ifdef DEBUG
if(i<0||i>=mm) laerror("index out of range in scale_col");
#endif
copyonwrite();
#ifdef CUDALA
if(location == cpu) {
#endif
cblas_dscal(nn, f, &(*this)(0,i), mm);
#ifdef CUDALA
}else{
cublasDscal(nn, f, v+i, mm);
TEST_CUBLAS("cublasDscal");
}
#endif
}
template<>
void NRMat<std::complex<double> >::scale_row(const int i, const std::complex<double> f)
{
#ifdef DEBUG
if(i<0||i>=nn) laerror("index out of range in scale_row");
#endif
copyonwrite();
#ifdef CUDALA
if(location == cpu) {
#endif
cblas_zscal(mm, &f, &(*this)(i,0), 1);
#ifdef CUDALA
}else{
const cuDoubleComplex fac = *(reinterpret_cast<const cuDoubleComplex*> (&f));
cublasZscal(mm, &fac, v+i*mm, 1);
TEST_CUBLAS("cublasDscal");
}
#endif
}
template<>
void NRMat<std::complex<double> >::scale_col(const int i, const std::complex<double> f)
{
#ifdef DEBUG
if(i<0||i>=mm) laerror("index out of range in scale_col");
#endif
copyonwrite();
#ifdef CUDALA
if(location == cpu) {
#endif
cblas_zscal(nn, &f, &(*this)(0,i), mm);
#ifdef CUDALA
}else{
const cuDoubleComplex fac = *(reinterpret_cast<const cuDoubleComplex*> (&f));
cublasZscal(nn, &fac, v+i, mm);
TEST_CUBLAS("cublasDscal");
}
#endif
}
//general version
template<typename T>
void NRMat<T>::scale_row(const int i, const T f)
{
#ifdef DEBUG
if(i<0||i>=nn) laerror("index out of range in scale_row");
#endif
copyonwrite();
for(int j=0; j<mm; ++j) (*this)(i,j) *= f;
}
template<typename T>
void NRMat<T>::scale_col(const int i, const T f)
{
#ifdef DEBUG
if(i<0||i>=mm) laerror("index out of range in scale_col");
#endif
copyonwrite();
for(int j=0; j<nn; ++j) (*this)(j,i) *= f;
}
/***************************************************************************//**