progressing implementation of permutations

This commit is contained in:
Jiri Pittner 2021-05-13 16:45:10 +02:00
parent 01665674c5
commit 60e8a379f5
9 changed files with 201 additions and 6 deletions

53
mat.cc
View File

@ -3075,6 +3075,59 @@ NRMat<T>& NRMat<T>::swap_rows_cols(){
return *this; return *this;
} }
//apply permutations
template<typename T>
const NRMat<T> NRMat<T>::permute_rows(const NRPerm<int> &p) const
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of matrix");
#endif
int n=p.size();
if(n!=nn) laerror("incompatible permutation and matrix");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRMat<T> r(nn,mm);
for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<mm; ++j) r(i-1,j) = (*this)(pi,j);}
return r;
}
template<typename T>
const NRMat<T> NRMat<T>::permute_cols(const NRPerm<int> &p) const
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of matrix");
#endif
int n=p.size();
if(n!=mm) laerror("incompatible permutation and matrix");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRMat<T> r(nn,mm);
for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=0; j<nn; ++j) r(j,i-1) = (*this)(j,pi);}
return r;
}
template<typename T>
const NRMat<T> NRMat<T>::permute_both(const NRPerm<int> &p, const NRPerm<int> &q) const
{
#ifdef DEBUG
if(!p.is_valid() || !q.is_valid() ) laerror("invalid permutation of matrix");
#endif
int n=p.size();
int m=q.size();
if(n!=nn ||m!=mm) laerror("incompatible permutation and matrix");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRMat<T> r(nn,mm);
for(int i=1; i<=n; ++i) {int pi=p[i]-1; for(int j=1; j<=m; ++j) r(i-1,j-1) = (*this)(pi,q[j]-1);}
return r;
}
/***************************************************************************//** /***************************************************************************//**
* forced instantization in the corresponding object file * forced instantization in the corresponding object file
******************************************************************************/ ******************************************************************************/

11
mat.h
View File

@ -24,6 +24,10 @@
namespace LA { namespace LA {
//forward declaration
template<typename T> class NRPerm;
/***************************************************************************//** /***************************************************************************//**
* \brief NRMat<T> class template implementing the matrix interface * \brief NRMat<T> class template implementing the matrix interface
* @see NRVec<T>, NRSMat<T> * @see NRVec<T>, NRSMat<T>
@ -111,6 +115,12 @@ public:
//! ensure that the data of this matrix are referenced exactly once //! ensure that the data of this matrix are referenced exactly once
void copyonwrite(bool detachonly=false); void copyonwrite(bool detachonly=false);
//! permute matrix elements
const NRMat permute_rows(const NRPerm<int> &p) const;
const NRMat permute_cols(const NRPerm<int> &p) const;
const NRMat permute_both(const NRPerm<int> &p, const NRPerm<int> &q) const;
/***************************************************************************//** /***************************************************************************//**
* routines for CUDA related stuff * routines for CUDA related stuff
* \li <code>getlocation()</code> gets the protected data member location * \li <code>getlocation()</code> gets the protected data member location
@ -376,6 +386,7 @@ public:
#include "smat.h" #include "smat.h"
#include "sparsemat.h" #include "sparsemat.h"
#include "sparsesmat.h" #include "sparsesmat.h"
#include "permutation.h"
namespace LA { namespace LA {

View File

@ -19,6 +19,32 @@
#include "permutation.h" #include "permutation.h"
namespace LA { namespace LA {
template <typename T>
void NRPerm<T>::identity()
{
T n=this->size();
#ifdef DEBUG
if(n<0) laerror("invalid permutation size");
#endif
if(n==0) return;
for(T i=1; i<=n; ++i) (*this)[i]=i;
}
template <typename T>
bool NRPerm<T>::is_identity() const
{
T n=this->size();
#ifdef DEBUG
if(n<0) laerror("invalid permutation size");
#endif
if(n==0) return 1;
for(T i=1; i<=n; ++i) if((*this)[i]!=i) return 0;
return 1;
}
template <typename T> template <typename T>
bool NRPerm<T>::is_valid() const bool NRPerm<T>::is_valid() const
{ {
@ -39,6 +65,10 @@ return 1;
template <typename T> template <typename T>
NRPerm<T> NRPerm<T>::inverse() const NRPerm<T> NRPerm<T>::inverse() const
{ {
#ifdef DEBUG
if(!this->is_valid()) laerror("inverse of invalid permutation");
#endif
NRPerm<T> q(this->size()); NRPerm<T> q(this->size());
for(T i=1; i<=this->size(); ++i) q[(*this)[i]]=i; for(T i=1; i<=this->size(); ++i) q[(*this)[i]]=i;
return q; return q;
@ -48,15 +78,44 @@ return q;
template <typename T> template <typename T>
NRPerm<T> NRPerm<T>::operator*(const NRPerm<T> q) const NRPerm<T> NRPerm<T>::operator*(const NRPerm<T> q) const
{ {
#ifdef DEBUG
if(!this->is_valid() || !q.is_valid()) laerror("multiplication of invalid permutation");
#endif
T n=this->size(); T n=this->size();
if(n!=q.size()) laerror("product of incompatible permutations"); if(n!=q.size()) laerror("product of incompatible permutations");
NRPerm<T> r(n); NRPerm<T> r(n);
for(T i=1; i<=n; ++i) r[i] = q[(*this)[i]]; for(T i=1; i<=n; ++i) r[i] = (*this)[q[i]];
return r; return r;
} }
template <typename T>
NRPerm<T> NRPerm<T>::conjugate_by(const NRPerm<T> q) const
{
#ifdef DEBUG
if(!this->is_valid() || !q.is_valid()) laerror("multiplication of invalid permutation");
#endif
T n=this->size();
if(n!=q.size()) laerror("product of incompatible permutations");
NRPerm<T> qi=q.inverse();
NRPerm<T> r(n);
for(T i=1; i<=n; ++i) r[i] = qi[(*this)[q[i]]];
return r;
}
template <typename T>
int NRPerm<T>::parity() const
{
if(!this->is_valid()) return 0;
T n=this->size();
if(n==1) return 1;
T count=0;
for(T i=2;i<=n;i++) for(T j=1;j<i;j++) if((*this)[j]>(*this)[i]) count++;
return (count&1)? -1:1;
}

View File

@ -28,6 +28,9 @@
namespace LA { namespace LA {
//forward declaration
template <typename T> class NRVec_from1;
template <typename T> template <typename T>
class NRPerm : public NRVec_from1<T> { class NRPerm : public NRVec_from1<T> {
public: public:
@ -39,17 +42,21 @@ public:
NRPerm(const T *a, const int n): NRVec_from1<T>(a, n) {}; NRPerm(const T *a, const int n): NRVec_from1<T>(a, n) {};
//specific operations //specific operations
void identity();
bool is_valid() const; //is it really a permutation bool is_valid() const; //is it really a permutation
bool is_identity() const;
NRPerm inverse() const; NRPerm inverse() const;
NRPerm operator*(const NRPerm rhs) const; NRPerm operator*(const NRPerm q) const; //q is rhs and applied first, this applied second
NRPerm conjugate_by(const NRPerm q) const; //q^-1 p q
int parity() const;
//TODO: //TODO:
//@@@conjugate by q //@@@permutation matrix
//@@@permgener //@@@permgener
//@@@lex rank
//@@@next permutation //@@@next permutation
//@@@lex rank
//@@@inversion tables //@@@inversion tables
//@@@parity
//@@@conversion to cycle structure and back //@@@conversion to cycle structure and back
}; };

18
smat.cc
View File

@ -303,6 +303,24 @@ void NRSMat<T>::fscanf(FILE *f, const char *format) {
laerror("NRSMat<T>::fscanf(FILE *, const char *) - unable to read matrix element"); laerror("NRSMat<T>::fscanf(FILE *, const char *) - unable to read matrix element");
} }
//apply permutation
template <typename T>
const NRSMat<T> NRSMat<T>::permute(const NRPerm<int> &p) const
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of smatrix");
#endif
int n=p.size();
if(n!=(*this).size()) laerror("incompatible permutation and smatrix");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRSMat<T> r(n);
for(int i=1; i<=n; ++i) {int pi = p[i]-1; r(i-1,i-1) = (*this)(pi,pi);}
return r;
}
/***************************************************************************//** /***************************************************************************//**
* multiply this real double-precision symmetric matrix \f$S\f$ stored in packed form * multiply this real double-precision symmetric matrix \f$S\f$ stored in packed form

6
smat.h
View File

@ -27,6 +27,9 @@
namespace LA { namespace LA {
#define NN2 ((size_t)nn*(nn+1)/2) #define NN2 ((size_t)nn*(nn+1)/2)
//forward declaration
template<typename T> class NRPerm;
/***************************************************************************//** /***************************************************************************//**
* This class implements a general symmetric or hermitian matrix the elements * This class implements a general symmetric or hermitian matrix the elements
@ -89,6 +92,8 @@ public:
//! assign scalar value to diagonal elements //! assign scalar value to diagonal elements
NRSMat & operator=(const T &a); NRSMat & operator=(const T &a);
//! permute matrix elements
const NRSMat permute(const NRPerm<int> &p) const;
inline int getcount() const {return count?*count:0;} inline int getcount() const {return count?*count:0;}
@ -176,6 +181,7 @@ public:
//due to mutual includes this has to be after full class declaration //due to mutual includes this has to be after full class declaration
#include "vec.h" #include "vec.h"
#include "mat.h" #include "mat.h"
#include "permutation.h"
namespace LA { namespace LA {

16
t.cc
View File

@ -22,6 +22,7 @@
#include "la.h" #include "la.h"
#include "vecmat3.h" #include "vecmat3.h"
#include "quaternion.h" #include "quaternion.h"
#include "permutation.h"
using namespace std; using namespace std;
using namespace LA_Vecmat3; using namespace LA_Vecmat3;
@ -1980,11 +1981,24 @@ cout <<"normquat2euler test "<<endl<<qq<<endl<<xqq<<endl<<qq-xqq<<endl;
} }
if(1) if(0)
{ {
NRVec<double> a,b,c; NRVec<double> a,b,c;
cin >>a>>b; cin >>a>>b;
c=a+b; c=a+b;
cout<<c; cout<<c;
} }
if(1)
{
NRPerm<int> p;
cin >>p;
int n=p.size();
NRVec_from1<double> v(n);
int i;
for(i=1; i<=n; ++i) v[i]=10.*i;
cout <<v.permute(p);
}
} }

16
vec.cc
View File

@ -833,6 +833,22 @@ NRVec<std::complex<double> > complexify(const NRVec<double> &rhs) {
return r; return r;
} }
template<typename T>
const NRVec<T> NRVec<T>::permute(const NRPerm<int> &p) const
{
#ifdef DEBUG
if(!p.is_valid()) laerror("invalid permutation of vector");
#endif
int n=p.size();
if(n!=(*this).size()) laerror("incompatible permutation and vector");
#ifdef CUDALA
if(this->getlocation() != cpu || p.getlocation() != cpu ) laerror("permutations can be done only in CPU memory");
#endif
NRVec<T> r(n);
for(int i=1; i<=n; ++i) r[i-1] = v[p[i]-1];
return r;
}
/***************************************************************************//** /***************************************************************************//**
* forced instantization in the corespoding object file * forced instantization in the corespoding object file
******************************************************************************/ ******************************************************************************/

11
vec.h
View File

@ -30,6 +30,8 @@ namespace LA {
template <typename T> void lawritemat(FILE *file, const T *a, int r, int c, template <typename T> void lawritemat(FILE *file, const T *a, int r, int c,
const char *form0, int nodim, int modulo, int issym); const char *form0, int nodim, int modulo, int issym);
template <typename T> class NRPerm;
/***************************************************************************//** /***************************************************************************//**
* static constants used in several cblas-routines * static constants used in several cblas-routines
******************************************************************************/ ******************************************************************************/
@ -260,6 +262,9 @@ public:
return sum; return sum;
}; };
//! permute vector elements
const NRVec permute(const NRPerm<int> &p) const;
//! compute the sum of the absolute values of the elements of this vector //! compute the sum of the absolute values of the elements of this vector
inline const typename LA_traits<T>::normtype asum() const; inline const typename LA_traits<T>::normtype asum() const;
@ -382,6 +387,12 @@ public:
inline T& operator[] (const int i); inline T& operator[] (const int i);
}; };
}//namespace
//needs NRVec_from1
#include "permutation.h"
namespace LA {
/***************************************************************************//** /***************************************************************************//**
* indexing operator giving the element at given position with range checking in * indexing operator giving the element at given position with range checking in
* the DEBUG mode * the DEBUG mode