tensor class -contraction
This commit is contained in:
		
							parent
							
								
									5c6cb43c61
								
							
						
					
					
						commit
						27cc7854f5
					
				
							
								
								
									
										58
									
								
								t.cc
									
									
									
									
									
								
							
							
						
						
									
										58
									
								
								t.cc
									
									
									
									
									
								
							| @ -3270,7 +3270,7 @@ for(int i=0; i<n; ++i) | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if(1) | if(0) | ||||||
| { | { | ||||||
| int n=5; | int n=5; | ||||||
| INDEXGROUP g; | INDEXGROUP g; | ||||||
| @ -3294,4 +3294,60 @@ cout <<e; | |||||||
| cout <<eu; | cout <<eu; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | if(1) | ||||||
|  | { | ||||||
|  | int n=5; | ||||||
|  | INDEXGROUP ag; | ||||||
|  | ag.number=4; | ||||||
|  | ag.symmetry= 1; | ||||||
|  | ag.offset=0; | ||||||
|  | ag.range=n; | ||||||
|  | 
 | ||||||
|  | Tensor<double> a(ag); | ||||||
|  | a.randomize(1.); | ||||||
|  | 
 | ||||||
|  | INDEXGROUP bg; | ||||||
|  | bg.number=3; | ||||||
|  | bg.symmetry= 0; | ||||||
|  | bg.offset=0; | ||||||
|  | bg.range=n; | ||||||
|  | 
 | ||||||
|  | Tensor<double> b(bg); | ||||||
|  | b.randomize(1.); | ||||||
|  | 
 | ||||||
|  | Tensor<double> cc = a.contraction(0,0,b,0,1); | ||||||
|  | cout <<cc; | ||||||
|  | 
 | ||||||
|  | INDEXGROUP cga; | ||||||
|  | cga.number=3; | ||||||
|  | cga.symmetry= 1; | ||||||
|  | cga.offset=0; | ||||||
|  | cga.range=n; | ||||||
|  | 
 | ||||||
|  | INDEXGROUP cgb; | ||||||
|  | cgb.number=2; | ||||||
|  | cgb.symmetry= 0; | ||||||
|  | cgb.offset=0; | ||||||
|  | cgb.range=n; | ||||||
|  | 
 | ||||||
|  | NRVec<INDEXGROUP> shape({cgb,cga}); | ||||||
|  | 
 | ||||||
|  | Tensor<double> c(shape); | ||||||
|  | c.clear(); | ||||||
|  | 
 | ||||||
|  | for(int i=0; i<n; ++i) | ||||||
|  | 	for(int j=0; j<=i; ++j) | ||||||
|  | 		for(int k=0; k<=j; ++k) | ||||||
|  | 			for(int l=0; l<n; ++l) | ||||||
|  | 				for(int m=0; m<n; ++m) | ||||||
|  | 					{ | ||||||
|  | 					for(int p=0; p<n; ++p) | ||||||
|  | 						c.lhs(m,l,k,j,i) += a(p,i,j,k) * b(m,p,l); | ||||||
|  | 					if(abs(c(m,l,k,j,i)-cc(m,l,k,j,i))>1e-13) laerror("internal error in conntraction"); | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | //cout <<c;
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
| } | } | ||||||
|  | |||||||
							
								
								
									
										69
									
								
								tensor.cc
									
									
									
									
									
								
							
							
						
						
									
										69
									
								
								tensor.cc
									
									
									
									
									
								
							| @ -589,6 +589,7 @@ if(shape[group].number==1) //single index in the group | |||||||
| 	NRPerm<int> p(shape.size()); | 	NRPerm<int> p(shape.size()); | ||||||
| 	p[1]= 1+group; | 	p[1]= 1+group; | ||||||
| 	int ii=1; | 	int ii=1; | ||||||
|  | 	if(ii==1+group) ii++; //skip this
 | ||||||
| 	for(int i=2; i<=shape.size(); ++i)  | 	for(int i=2; i<=shape.size(); ++i)  | ||||||
| 		{ | 		{ | ||||||
| 		p[i]=ii++; | 		p[i]=ii++; | ||||||
| @ -625,12 +626,17 @@ if(r.rank()!=rank()) laerror("internal error 2 in unwind_index"); | |||||||
| NRPerm<int> indexperm(rank()); | NRPerm<int> indexperm(rank()); | ||||||
| indexperm[1]=flatindex+1; | indexperm[1]=flatindex+1; | ||||||
| int ii=1; | int ii=1; | ||||||
|  | if(ii==flatindex+1) ii++; | ||||||
| for(int i=2; i<=rank(); ++i)  | for(int i=2; i<=rank(); ++i)  | ||||||
| 	{ | 	{ | ||||||
| 	indexperm[i] = ii++; | 	indexperm[i] = ii++; | ||||||
| 	if(ii==flatindex+1) ii++; //skip this
 | 	if(ii==flatindex+1) ii++; //skip this
 | ||||||
| 	} | 	} | ||||||
| if(!indexperm.is_valid()) laerror("internal error 3 in unwind_index"); | if(!indexperm.is_valid())  | ||||||
|  | 	{ | ||||||
|  | 	std::cout << "indexperm = "<<indexperm<<std::endl; | ||||||
|  | 	laerror("internal error 3 in unwind_index"); | ||||||
|  | 	} | ||||||
| 
 | 
 | ||||||
| //loop recursively and do the unwinding
 | //loop recursively and do the unwinding
 | ||||||
| help_tt<T> = this; | help_tt<T> = this; | ||||||
| @ -640,6 +646,67 @@ return r; | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | template<typename T> | ||||||
|  | static void auxmatmult(int nn, int mm, int kk, T *r, T *a, T *b,  T alpha=1, T beta=0) //R(nn,mm) = A * B^T
 | ||||||
|  | { | ||||||
|  | for(int i=0; i<nn; ++i) for(int j=0; j<mm; ++j) | ||||||
|  | 	{ | ||||||
|  | 	if(beta==0)  r[i*mm+j]=0; else r[i*mm+j] *= beta; | ||||||
|  | 	for(int k=0; k<kk; ++k) r[i*mm+j] += alpha * a[i*kk+k] * b[j*kk+k]; | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | template<> | ||||||
|  | void auxmatmult<double>(int nn, int mm, int kk, double *r, double *a, double *b, double alpha, double beta) | ||||||
|  | { | ||||||
|  | cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, alpha, a, kk, b, kk, beta, r, mm); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template<> | ||||||
|  | void auxmatmult<std::complex<double> >(int nn, int mm, int kk, std::complex<double> *r, std::complex<double> *a, std::complex<double> *b,  std::complex<double> alpha,  std::complex<double> beta) | ||||||
|  | { | ||||||
|  | cblas_zgemm(CblasRowMajor, CblasNoTrans, CblasTrans, nn, mm, kk, &alpha, a, kk, b, kk, &beta, r, mm); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | //Conntraction could be implemented without the temporary storage for unwinding, but then we would need
 | ||||||
|  | //double recursion over indices of both tensors. Hopefully using the matrix multiplication here
 | ||||||
|  | //makes it also more efficient, even for (anti)symmetric indices
 | ||||||
|  | //The index unwinding is unfortunately a big burden, and in principle could be eliminated in case of non-symmetric indices
 | ||||||
|  | //
 | ||||||
|  | template<typename T> | ||||||
|  | Tensor<T> Tensor<T>::contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha) const | ||||||
|  | { | ||||||
|  | if(group<0||group>=shape.size()) laerror("wrong group number in contraction"); | ||||||
|  | if(rhsgroup<0||rhsgroup>=rhs.shape.size()) laerror("wrong rhsgroup number in contraction"); | ||||||
|  | if(index<0||index>=shape[group].number)  laerror("wrong index number in conntraction"); | ||||||
|  | if(rhsindex<0||rhsindex>=rhs.shape[rhsgroup].number)  laerror("wrong index number in conntraction"); | ||||||
|  | if(shape[group].offset != rhs.shape[rhsgroup].offset) laerror("incompatible index offset in contraction"); | ||||||
|  | if(shape[group].range != rhs.shape[rhsgroup].range) laerror("incompatible index range in contraction"); | ||||||
|  | 
 | ||||||
|  | Tensor<T> u = unwind_index(group,index); | ||||||
|  | Tensor<T> rhsu = rhs.unwind_index(rhsgroup,rhsindex); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | NRVec<indexgroup> newshape(u.shape.size()+rhsu.shape.size()-2); | ||||||
|  | int ii=0; | ||||||
|  | for(int i=1; i<rhsu.shape.size(); ++i) newshape[ii++] = rhsu.shape[i]; | ||||||
|  | for(int i=1; i<u.shape.size(); ++i) newshape[ii++] = u.shape[i]; //this tensor will have more significant indices than the rhs one
 | ||||||
|  | 
 | ||||||
|  | Tensor<T> r(newshape); | ||||||
|  | int nn,mm,kk; | ||||||
|  | kk=u.groupsizes[0]; | ||||||
|  | if(kk!=rhsu.groupsizes[0]) laerror("internal error in contraction"); | ||||||
|  | nn=1; for(int i=1; i<u.shape.size(); ++i) nn*= u.groupsizes[i]; | ||||||
|  | mm=1; for(int i=1; i<rhsu.shape.size(); ++i) mm*= rhsu.groupsizes[i]; | ||||||
|  | auxmatmult<T>(nn,mm,kk,&r.data[0],&u.data[0], &rhsu.data[0],alpha); | ||||||
|  | return r; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| template class Tensor<double>; | template class Tensor<double>; | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								tensor.h
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								tensor.h
									
									
									
									
									
								
							| @ -179,10 +179,10 @@ public: | |||||||
| 
 | 
 | ||||||
| 	Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
 | 	Tensor permute_index_groups(const NRPerm<int> &p) const; //rearrange the tensor storage permuting index groups as a whole
 | ||||||
| 	Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
 | 	Tensor unwind_index(int group, int index) const; //separate an index from a group and expand it to full range as the least significant one
 | ||||||
|  | 	Tensor contraction(int group, int index, const Tensor &rhs, int rhsgroup, int rhsindex, T alpha=1) const; | ||||||
| 
 | 
 | ||||||
| 	//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
 | 	//@@@ general antisymmetrization operator Kucharski style - or that will be left to a code generator?
 | ||||||
| 	//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne
 | 	//@@@symmetrize a group, antisymmetrize a group, expand a (anti)symmetric group - obecne symmetry change krome +1 na -1 vse mozne
 | ||||||
| 	//@@@contraction
 |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user