working on tensor - outer product
This commit is contained in:
13
tensor.cc
13
tensor.cc
@@ -726,11 +726,16 @@ for(int i=0; i<il.size(); ++i)
|
||||
|
||||
//then the remaining groups with one index removed, if nonempty
|
||||
int ii=il.size();
|
||||
int emptied_groups=0;
|
||||
for(int i=0; i<oldshape.size(); ++i)
|
||||
if(oldshape[i].number>0)
|
||||
{
|
||||
newshape[ii++] = oldshape[i];
|
||||
}
|
||||
else
|
||||
++emptied_groups;
|
||||
|
||||
if(emptied_groups) newshape.resize(newshape.size()-emptied_groups,true);
|
||||
|
||||
Tensor<T> r(newshape);
|
||||
if(r.rank()!=rank()) laerror("internal error 2 in unwind_indces");
|
||||
@@ -1019,6 +1024,14 @@ return r;
|
||||
}
|
||||
|
||||
|
||||
//outer product, rhs indices will be the less significant than this
|
||||
template<typename T>
|
||||
Tensor<T> Tensor<T>::operator*(const Tensor &rhs) const
|
||||
{
|
||||
Tensor<T> r(rhs.shape.concat(shape));
|
||||
r.data= data.otimes2vec(rhs.data);
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user