tensor: tested full contractions
This commit is contained in:
15
t.cc
15
t.cc
@@ -3976,6 +3976,7 @@ cout <<t.dot(u)<<endl;
|
||||
|
||||
if(1)
|
||||
{
|
||||
//check full constractions
|
||||
int r,n;
|
||||
cin>>r>>n;
|
||||
INDEXGROUP shape;
|
||||
@@ -3986,13 +3987,17 @@ INDEXGROUP shape;
|
||||
shape.range=n;
|
||||
shape.offset=0;
|
||||
}
|
||||
Tensor<double> x(shape);
|
||||
x.randomize(1.);
|
||||
cout<<x;
|
||||
Tensor<double> x(shape); x.randomize(1.);
|
||||
Tensor<double> xf=x.flatten();
|
||||
cout <<xf;
|
||||
Tensor<double> y(shape); y.randomize(1.);
|
||||
Tensor<double> yf=y.flatten();
|
||||
|
||||
cout <<x.dot(x) <<" "<< xf.dot(xf)<<endl;
|
||||
Tensor<double> z = x.groupcontraction(0,y,0,1,false,true);
|
||||
INDEXLIST cl(r);
|
||||
for(int i=0; i<r; ++i) cl[i]={i,0};
|
||||
Tensor<double> zf = xf.contractions(cl,yf,cl,1,false,true);
|
||||
|
||||
cout <<x.dot(y) <<" "<< xf.dot(yf)<< " "<<z<<" "<<zf<<endl;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user