suport for data preserving in NRVec::resize

This commit is contained in:
Jiri Pittner 2021-06-09 15:33:24 +02:00
parent 74835e5264
commit 5480de6ff2
3 changed files with 86 additions and 28 deletions

View File

@ -81,6 +81,10 @@ template<typename C> class NRSMat_from1;
template<typename C> class SparseMat; template<typename C> class SparseMat;
template<typename C> class SparseSMat; template<typename C> class SparseSMat;
template<typename C> class CSRMat; template<typename C> class CSRMat;
template<typename C> class NRPerm;
template<typename C> class CyclePerm;
template<typename C> class Partition;
template<typename C> class CompressedPartition;
//trick to allow real and imag part of complex as l-values //trick to allow real and imag part of complex as l-values
template<typename T> template<typename T>
@ -375,6 +379,11 @@ generate_traits(NRVec_from1)
generate_traits(SparseMat) generate_traits(SparseMat)
generate_traits(SparseSMat) //product leading to non-symmetric result not implemented generate_traits(SparseSMat) //product leading to non-symmetric result not implemented
generate_traits(CSRMat) generate_traits(CSRMat)
generate_traits(NRPerm)
generate_traits(CyclePerm)
generate_traits(Partition)
generate_traits(CompressedPartition)
#undef generate_traits #undef generate_traits

4
t.cc
View File

@ -2146,7 +2146,7 @@ int tot=p.generate_all_lex(printme);
cout <<"generated "<<tot<<endl; cout <<"generated "<<tot<<endl;
} }
if(0) if(1)
{ {
int n; int n;
cin >>n >>unitary_n; cin >>n >>unitary_n;
@ -2158,7 +2158,7 @@ if(tot!=partitions(n)) laerror("internal error in partition generation or enumer
if(space_dim!=longpow(unitary_n,n)) {cout<<space_dim<<" "<<ipow(unitary_n,n)<<endl;laerror("integer overflow or internal error in space dimensions");} if(space_dim!=longpow(unitary_n,n)) {cout<<space_dim<<" "<<ipow(unitary_n,n)<<endl;laerror("integer overflow or internal error in space dimensions");}
} }
if(1) if(0)
{ {
int n; int n;
cin >>n ; cin >>n ;

85
vec.h
View File

@ -299,8 +299,8 @@ public:
//! determine the number of elements //! determine the number of elements
inline int size() const; inline int size() const;
//! resize the current vector //! resize the current vector, optionally preserving data
void resize(const int n); void resize(const int n, const bool preserve=false);
//!deallocate the current vector //!deallocate the current vector
void dealloc(void) {resize(0);} void dealloc(void) {resize(0);}
@ -965,14 +965,25 @@ NRVec<T> & NRVec<T>::operator=(const NRVec<T> &rhs) {
* @param[in] n requested size * @param[in] n requested size
******************************************************************************/ ******************************************************************************/
template <typename T> template <typename T>
void NRVec<T>::resize(const int n) { void NRVec<T>::resize(const int n, const bool preserve)
{
#ifdef DEBUG #ifdef DEBUG
if(n < 0) laerror("illegal dimension"); if(n < 0) laerror("illegal dimension in NRVec::resize");
#endif #endif
if(count){ if(preserve && n<nn) laerror("cannot resize to smaller vector and preserve data");
if(n == 0){ T *vold=0;
if(--(*count) <= 0){ int nnold=0;
if(v){ bool preserved=false;
bool do_delete=false;
if(count) //we are allocated
{
if(n == 0) //just deallocate
{
if(--(*count) <= 0)
{
if(v)
{
#ifdef CUDALA #ifdef CUDALA
if(location == cpu){ if(location == cpu){
#endif #endif
@ -990,14 +1001,19 @@ void NRVec<T>::resize(const int n) {
v = 0; v = 0;
return; return;
} }
if(*count > 1) { if(*count > 1) //detach from shared data
{
(*count)--; (*count)--;
count = 0; count = 0;
vold=v;
v = 0; v = 0;
nnold=nn;
nn = 0; nn = 0;
preserved=true;
} }
} }
if(!count){ if(!count) //we were not allocated or we just detached
{
count = new int; count = new int;
*count = 1; *count = 1;
nn = n; nn = n;
@ -1009,25 +1025,58 @@ void NRVec<T>::resize(const int n) {
else else
v = (T*) gpualloc(nn*sizeof(T)); v = (T*) gpualloc(nn*sizeof(T));
#endif #endif
if(preserved && preserve) goto do_preserve;
return; return;
} }
// *count = 1 in this branch // *count == 1 in this branch
if (n != nn) { if (n == nn) return; //nothing to do
nnold=nn;
nn = n; nn = n;
#ifdef CUDALA #ifdef CUDALA
if(location == cpu){ if(location == cpu)
{
#endif #endif
if(preserve) {vold=v; do_delete=true;} else delete[] v;
delete[] v;
v = new T[nn]; v = new T[nn];
#ifdef CUDALA #ifdef CUDALA
}else{ }
else
gpufree(v); {
if(preserve) {vold=v; do_delete=true;} else gpufree(v);
v = (T*) gpualloc(nn*sizeof(T)); v = (T*) gpualloc(nn*sizeof(T));
} }
#endif #endif
if(!preserve) return;
//copy data from old location and zero excess allocated memory
do_preserve:
if(!preserve || !preserved) laerror("assertion failed in NRVec::resize");
// omit this check since we would need to have traits for presently unknown user defined classes
// if(!LA_traits<T>::is_plaindata()) laerror("do not know how to preserve non-plain data");
if(nnold>=nn) laerror("assertion2 failed in NRVec::resize");
#ifdef CUDALA
if(location == cpu)
{
#endif
for(int i=0; i<nnold; ++i) v[i]=vold[i]; //preserve even non-plain data classes
memset(v+nnold,0,(nn-nnold)*sizeof(T)); //just zero the new memory
if(do_delete) delete[] vold;
#ifdef CUDALA
} }
else
{
//!!!works only with plain data
cublasSetVector(nnold, sizeof(T), vold, 1, v, 1);
TEST_CUBLAS("cublasSetVector");
T a(0);
smart_gpu_set(nn-nnold, a, v+nnold);
if(do_delete) gpufree(vold);
}
#endif
return;
} }