suport for data preserving in NRVec::resize
This commit is contained in:
parent
74835e5264
commit
5480de6ff2
@ -81,6 +81,10 @@ template<typename C> class NRSMat_from1;
|
|||||||
template<typename C> class SparseMat;
|
template<typename C> class SparseMat;
|
||||||
template<typename C> class SparseSMat;
|
template<typename C> class SparseSMat;
|
||||||
template<typename C> class CSRMat;
|
template<typename C> class CSRMat;
|
||||||
|
template<typename C> class NRPerm;
|
||||||
|
template<typename C> class CyclePerm;
|
||||||
|
template<typename C> class Partition;
|
||||||
|
template<typename C> class CompressedPartition;
|
||||||
|
|
||||||
//trick to allow real and imag part of complex as l-values
|
//trick to allow real and imag part of complex as l-values
|
||||||
template<typename T>
|
template<typename T>
|
||||||
@ -375,6 +379,11 @@ generate_traits(NRVec_from1)
|
|||||||
generate_traits(SparseMat)
|
generate_traits(SparseMat)
|
||||||
generate_traits(SparseSMat) //product leading to non-symmetric result not implemented
|
generate_traits(SparseSMat) //product leading to non-symmetric result not implemented
|
||||||
generate_traits(CSRMat)
|
generate_traits(CSRMat)
|
||||||
|
generate_traits(NRPerm)
|
||||||
|
generate_traits(CyclePerm)
|
||||||
|
generate_traits(Partition)
|
||||||
|
generate_traits(CompressedPartition)
|
||||||
|
|
||||||
|
|
||||||
#undef generate_traits
|
#undef generate_traits
|
||||||
|
|
||||||
|
4
t.cc
4
t.cc
@ -2146,7 +2146,7 @@ int tot=p.generate_all_lex(printme);
|
|||||||
cout <<"generated "<<tot<<endl;
|
cout <<"generated "<<tot<<endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(0)
|
if(1)
|
||||||
{
|
{
|
||||||
int n;
|
int n;
|
||||||
cin >>n >>unitary_n;
|
cin >>n >>unitary_n;
|
||||||
@ -2158,7 +2158,7 @@ if(tot!=partitions(n)) laerror("internal error in partition generation or enumer
|
|||||||
if(space_dim!=longpow(unitary_n,n)) {cout<<space_dim<<" "<<ipow(unitary_n,n)<<endl;laerror("integer overflow or internal error in space dimensions");}
|
if(space_dim!=longpow(unitary_n,n)) {cout<<space_dim<<" "<<ipow(unitary_n,n)<<endl;laerror("integer overflow or internal error in space dimensions");}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(1)
|
if(0)
|
||||||
{
|
{
|
||||||
int n;
|
int n;
|
||||||
cin >>n ;
|
cin >>n ;
|
||||||
|
87
vec.h
87
vec.h
@ -299,8 +299,8 @@ public:
|
|||||||
//! determine the number of elements
|
//! determine the number of elements
|
||||||
inline int size() const;
|
inline int size() const;
|
||||||
|
|
||||||
//! resize the current vector
|
//! resize the current vector, optionally preserving data
|
||||||
void resize(const int n);
|
void resize(const int n, const bool preserve=false);
|
||||||
|
|
||||||
//!deallocate the current vector
|
//!deallocate the current vector
|
||||||
void dealloc(void) {resize(0);}
|
void dealloc(void) {resize(0);}
|
||||||
@ -965,14 +965,25 @@ NRVec<T> & NRVec<T>::operator=(const NRVec<T> &rhs) {
|
|||||||
* @param[in] n requested size
|
* @param[in] n requested size
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void NRVec<T>::resize(const int n) {
|
void NRVec<T>::resize(const int n, const bool preserve)
|
||||||
|
{
|
||||||
#ifdef DEBUG
|
#ifdef DEBUG
|
||||||
if(n < 0) laerror("illegal dimension");
|
if(n < 0) laerror("illegal dimension in NRVec::resize");
|
||||||
#endif
|
#endif
|
||||||
if(count){
|
if(preserve && n<nn) laerror("cannot resize to smaller vector and preserve data");
|
||||||
if(n == 0){
|
T *vold=0;
|
||||||
if(--(*count) <= 0){
|
int nnold=0;
|
||||||
if(v){
|
bool preserved=false;
|
||||||
|
bool do_delete=false;
|
||||||
|
|
||||||
|
if(count) //we are allocated
|
||||||
|
{
|
||||||
|
if(n == 0) //just deallocate
|
||||||
|
{
|
||||||
|
if(--(*count) <= 0)
|
||||||
|
{
|
||||||
|
if(v)
|
||||||
|
{
|
||||||
#ifdef CUDALA
|
#ifdef CUDALA
|
||||||
if(location == cpu){
|
if(location == cpu){
|
||||||
#endif
|
#endif
|
||||||
@ -990,14 +1001,19 @@ void NRVec<T>::resize(const int n) {
|
|||||||
v = 0;
|
v = 0;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if(*count > 1) {
|
if(*count > 1) //detach from shared data
|
||||||
|
{
|
||||||
(*count)--;
|
(*count)--;
|
||||||
count = 0;
|
count = 0;
|
||||||
|
vold=v;
|
||||||
v = 0;
|
v = 0;
|
||||||
|
nnold=nn;
|
||||||
nn = 0;
|
nn = 0;
|
||||||
|
preserved=true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(!count){
|
if(!count) //we were not allocated or we just detached
|
||||||
|
{
|
||||||
count = new int;
|
count = new int;
|
||||||
*count = 1;
|
*count = 1;
|
||||||
nn = n;
|
nn = n;
|
||||||
@ -1009,25 +1025,58 @@ void NRVec<T>::resize(const int n) {
|
|||||||
else
|
else
|
||||||
v = (T*) gpualloc(nn*sizeof(T));
|
v = (T*) gpualloc(nn*sizeof(T));
|
||||||
#endif
|
#endif
|
||||||
|
if(preserved && preserve) goto do_preserve;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// *count = 1 in this branch
|
// *count == 1 in this branch
|
||||||
if (n != nn) {
|
if (n == nn) return; //nothing to do
|
||||||
nn = n;
|
nnold=nn;
|
||||||
|
nn = n;
|
||||||
#ifdef CUDALA
|
#ifdef CUDALA
|
||||||
if(location == cpu){
|
if(location == cpu)
|
||||||
|
{
|
||||||
#endif
|
#endif
|
||||||
|
if(preserve) {vold=v; do_delete=true;} else delete[] v;
|
||||||
delete[] v;
|
|
||||||
v = new T[nn];
|
v = new T[nn];
|
||||||
#ifdef CUDALA
|
#ifdef CUDALA
|
||||||
}else{
|
}
|
||||||
|
else
|
||||||
gpufree(v);
|
{
|
||||||
|
if(preserve) {vold=v; do_delete=true;} else gpufree(v);
|
||||||
v = (T*) gpualloc(nn*sizeof(T));
|
v = (T*) gpualloc(nn*sizeof(T));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
if(!preserve) return;
|
||||||
|
|
||||||
|
//copy data from old location and zero excess allocated memory
|
||||||
|
do_preserve:
|
||||||
|
if(!preserve || !preserved) laerror("assertion failed in NRVec::resize");
|
||||||
|
// omit this check since we would need to have traits for presently unknown user defined classes
|
||||||
|
// if(!LA_traits<T>::is_plaindata()) laerror("do not know how to preserve non-plain data");
|
||||||
|
if(nnold>=nn) laerror("assertion2 failed in NRVec::resize");
|
||||||
|
|
||||||
|
#ifdef CUDALA
|
||||||
|
if(location == cpu)
|
||||||
|
{
|
||||||
|
#endif
|
||||||
|
for(int i=0; i<nnold; ++i) v[i]=vold[i]; //preserve even non-plain data classes
|
||||||
|
memset(v+nnold,0,(nn-nnold)*sizeof(T)); //just zero the new memory
|
||||||
|
if(do_delete) delete[] vold;
|
||||||
|
#ifdef CUDALA
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
//!!!works only with plain data
|
||||||
|
cublasSetVector(nnold, sizeof(T), vold, 1, v, 1);
|
||||||
|
TEST_CUBLAS("cublasSetVector");
|
||||||
|
T a(0);
|
||||||
|
smart_gpu_set(nn-nnold, a, v+nnold);
|
||||||
|
if(do_delete) gpufree(vold);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user