Our implementation of `tensordot` copies the NumPy one line by line, and that implementation relies on `reshape`. `Reshape` is a tricky operator to reason about symbolically because it can do many things... (see #883)
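The following shape-only sketch shows where the `reshape` calls come from. It follows the same plan as NumPy's `tensordot`: move the contracted axes of `a` to the end and those of `b` to the front, flatten both operands to 2-D, take a single matrix product, then reshape the result back. `tensordot_plan` is a hypothetical helper written for illustration. It is not part of PyTensor or NumPy, and it only computes permutations and shapes.

```python
from math import prod


def tensordot_plan(shape_a, shape_b, axes_a, axes_b):
    """Hypothetical helper: the transpose/reshape plan of a NumPy-style tensordot.

    Works on shapes only; no arrays are created.
    """
    nd_a, nd_b = len(shape_a), len(shape_b)
    # Normalise negative axis indices (e.g. -1 -> last axis).
    axes_a = [ax % nd_a for ax in axes_a]
    axes_b = [ax % nd_b for ax in axes_b]
    if [shape_a[i] for i in axes_a] != [shape_b[j] for j in axes_b]:
        raise ValueError("shape mismatch for contracted axes")
    # Axes that survive in the output, in their original order.
    free_a = [i for i in range(nd_a) if i not in axes_a]
    free_b = [j for j in range(nd_b) if j not in axes_b]
    # Total size of the contraction once the contracted axes are flattened.
    k = prod(shape_a[i] for i in axes_a)
    return {
        "perm_a": free_a + axes_a,  # contracted axes of a moved last
        "perm_b": axes_b + free_b,  # contracted axes of b moved first
        "reshape_a": (prod(shape_a[i] for i in free_a), k),
        "reshape_b": (k, prod(shape_b[j] for j in free_b)),
        # Final reshape: free axes of a followed by free axes of b.
        "out_shape": tuple(shape_a[i] for i in free_a)
        + tuple(shape_b[j] for j in free_b),
    }


print(tensordot_plan((5, 3), (3,), [-1], [-1]))
```

Take the `x`/`y` example from this issue, with shapes `(5, 3)` and `(3,)`. The plan reshapes `y` to `(3, 1)`, multiplies to get a `(5, 1)` result, and reshapes that back to `(5,)`. This matches the two `Reshape` nodes around `Dot22` in the compiled graph. An empty product counts as 1, which is where the extra length-1 dimension comes from.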
When `tensordot` is equivalent to `dot` (matrix-matrix or matrix-vector products), we should shortcut to `dot`. This happens when there is exactly one axis of reduction per input and both inputs are at most 2D (we could go further and reason about non-broadcastable dimensions). It may require a transpose before and after calling `dot`. This will also improve `dot_general` and `einsum`, which use `tensordot` internally.
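The following sketch expresses the shortcut condition as a check. It assumes the convention that `dot` contracts the last axis of its first argument with the first axis of a 2-D second argument, or with the only axis of a 1-D one. `dot_shortcut` is a hypothetical helper, not PyTensor API. It returns `None` when the rewrite does not apply. Otherwise it returns flags saying which inputs need a transpose before calling `dot`.

```python
def dot_shortcut(ndim_a, ndim_b, axes_a, axes_b):
    """Hypothetical check: can tensordot(a, b, [axes_a, axes_b]) become dot?

    Returns None if not, otherwise (transpose_a, transpose_b).
    """
    # Only the simple case: one reduced axis per input ...
    if len(axes_a) != 1 or len(axes_b) != 1:
        return None
    # ... and both inputs are vectors or matrices.
    if ndim_a not in (1, 2) or ndim_b not in (1, 2):
        return None
    # Normalise negative axes.
    ax_a = axes_a[0] % ndim_a
    ax_b = axes_b[0] % ndim_b
    # dot reduces the last axis of a and the first axis of b;
    # a 1-D input has a single axis, so it never needs a transpose.
    transpose_a = ndim_a == 2 and ax_a == 0
    transpose_b = ndim_b == 2 and ax_b == 1
    return transpose_a, transpose_b


print(dot_shortcut(2, 1, [-1], [-1]))  # the x/y example from this issue
print(dot_shortcut(2, 2, [0], [1]))
print(dot_shortcut(3, 2, [2], [0]))
```

In the issue's example (`x` a matrix, `y` a vector, `axes=[[-1], [-1]]`), neither flag is set, so the rewrite is simply `x @ y`. When both flags are set, `a.T @ b.T` can equally be written `(b @ a).T`, which trades two input transposes for one on the output. This sketch only places transposes on the inputs. Because `dot` already returns the free axes of `a` followed by those of `b`, as `tensordot` does, no output transpose is needed here.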
```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.vector("y")

# Contract the last axis of x with the only axis of y (a matrix-vector product)
tensordot_out = pt.tensordot(x, y, axes=[[-1], [-1]])
tensordot_fn = pytensor.function([x, y], tensordot_out)
tensordot_fn.dprint(print_type=True)
```
```
Reshape{1} [id A] <Vector(float64, shape=(?,))> 6
 ├─ Dot22 [id B] <Matrix(float64, shape=(?, ?))> 5
 │  ├─ x [id C] <Matrix(float64, shape=(?, ?))>
 │  └─ Reshape{2} [id D] <Matrix(float64, shape=(?, ?))> 4
 │     ├─ y [id E] <Vector(float64, shape=(?,))>
 │     └─ MakeVector{dtype='int64'} [id F] <Vector(int64, shape=(2,))> 3
 │        ├─ Shape_i{0} [id G] <Scalar(int64, shape=())> 2
 │        │  └─ y [id E] <Vector(float64, shape=(?,))>
 │        └─ -1 [id H] <Scalar(int64, shape=())>
 └─ MakeVector{dtype='int64'} [id I] <Vector(int64, shape=(1,))> 1
    └─ Shape_i{0} [id J] <Scalar(int64, shape=())> 0
       └─ x [id C] <Matrix(float64, shape=(?, ?))>
```
```python
# The same matrix-vector product written with `@` (dot)
dot_out = x @ y
dot_fn = pytensor.function([x, y], dot_out)
dot_fn.dprint(print_type=True)
```
```
CGemv{inplace} [id A] <Vector(float64, shape=(?,))> 2
 ├─ AllocEmpty{dtype='float64'} [id B] <Vector(float64, shape=(?,))> 1
 │  └─ Shape_i{0} [id C] <Scalar(int64, shape=())> 0
 │     └─ x [id D] <Matrix(float64, shape=(?, ?))>
 ├─ 1.0 [id E] <Scalar(float64, shape=())>
 ├─ x [id D] <Matrix(float64, shape=(?, ?))>
 ├─ y [id F] <Vector(float64, shape=(?,))>
 └─ 0.0 [id G] <Scalar(float64, shape=())>
```
```python
import numpy as np

# Both compiled functions should produce the same result
x_test = np.random.normal(size=(5, 3))
y_test = np.random.normal(size=(3,))
np.testing.assert_allclose(
    tensordot_fn(x_test, y_test),
    dot_fn(x_test, y_test),
)
```