tensor: special cases of ourter product with trivial tensor

This commit is contained in:
2025-11-11 19:48:44 +01:00
parent 98ef46ad47
commit 641b632df1
2 changed files with 23 additions and 1 deletions

23
t.cc
View File

@@ -4091,7 +4091,7 @@ for(int i=0;i<n+1;++i) for(int j=0;j<n+2;++j)
cout <<"Error= "<<e<<endl; cout <<"Error= "<<e<<endl;
} }
if(1) if(0)
{ {
int r=4; int r=4;
int n=5; int n=5;
@@ -4120,5 +4120,26 @@ Tensor<double> ss(s);
cout <<"Error= "<<ss-xc<<endl; cout <<"Error= "<<ss-xc<<endl;
} }
if(1)
{
int r=1;
int n=5;
NRVec<INDEXGROUP> sh(r);
for(int i=0; i<r; ++i)
{
sh[i].number=1;
sh[i].symmetry=0;
sh[i].range=n;
sh[i].offset=0;
}
Tensor<double> x(sh); x.randomize(1.);
Tensor<double> s1(2.);
Tensor<double> s2(3.);
cout <<s1*x<<endl;
cout <<x*s1<<endl;
cout<<s1*s2<<endl;
}
}//main }//main

View File

@@ -1831,6 +1831,7 @@ return r;
template<typename T> template<typename T>
void Tensor<T>::canonicalize_shape() void Tensor<T>::canonicalize_shape()
{ {
if(shape.size()==0) return;
const INDEXGROUP *sh = &(* const_cast<const NRVec<INDEXGROUP> *>(&shape))[0]; const INDEXGROUP *sh = &(* const_cast<const NRVec<INDEXGROUP> *>(&shape))[0];
for(int i=0; i<shape.size(); ++i) for(int i=0; i<shape.size(); ++i)
{ {