Shortcut to dot when calling tensordot #1201

Closed
@ricardoV94

Description

Our implementation of tensordot copies the numpy one line by line, and so relies on reshape. Reshape is a tricky operator to reason about symbolically because it can do many things ... #883

For the cases where tensordot is just the same as dot (matrix-matrix or matrix-vector products) we should shortcut to it. This happens when there is exactly one axis of reduction per input and both inputs are at most 2d (we could go further and consider only non-broadcastable dimensions). It may require a transpose before and after calling dot. This will also improve dot_general and einsum, which use tensordot internally.

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.vector("y")

tensordot_out = pt.tensordot(x, y, axes=[[-1], [-1]])
tensordot_fn = pytensor.function([x, y], tensordot_out)
tensordot_fn.dprint(print_type=True)
```

```
Reshape{1} [id A] <Vector(float64, shape=(?,))> 6
 ├─ Dot22 [id B] <Matrix(float64, shape=(?, ?))> 5
 │  ├─ x [id C] <Matrix(float64, shape=(?, ?))>
 │  └─ Reshape{2} [id D] <Matrix(float64, shape=(?, ?))> 4
 │     ├─ y [id E] <Vector(float64, shape=(?,))>
 │     └─ MakeVector{dtype='int64'} [id F] <Vector(int64, shape=(2,))> 3
 │        ├─ Shape_i{0} [id G] <Scalar(int64, shape=())> 2
 │        │  └─ y [id E] <Vector(float64, shape=(?,))>
 │        └─ -1 [id H] <Scalar(int64, shape=())>
 └─ MakeVector{dtype='int64'} [id I] <Vector(int64, shape=(1,))> 1
    └─ Shape_i{0} [id J] <Scalar(int64, shape=())> 0
       └─ x [id C] <Matrix(float64, shape=(?, ?))>
```

```python
dot_out = x @ y
dot_fn = pytensor.function([x, y], dot_out)
dot_fn.dprint(print_type=True)
```

```
CGemv{inplace} [id A] <Vector(float64, shape=(?,))> 2
 ├─ AllocEmpty{dtype='float64'} [id B] <Vector(float64, shape=(?,))> 1
 │  └─ Shape_i{0} [id C] <Scalar(int64, shape=())> 0
 │     └─ x [id D] <Matrix(float64, shape=(?, ?))>
 ├─ 1.0 [id E] <Scalar(float64, shape=())>
 ├─ x [id D] <Matrix(float64, shape=(?, ?))>
 ├─ y [id F] <Vector(float64, shape=(?,))>
 └─ 0.0 [id G] <Scalar(float64, shape=())>
```

```python
import numpy as np
x_test = np.random.normal(size=(5, 3))
y_test = np.random.normal(size=(3,))

np.testing.assert_allclose(
  tensordot_fn(x_test, y_test),
  dot_fn(x_test, y_test),
)
```
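To make the proposed condition concrete, here is a rough sketch of the shortcut logic written against plain NumPy rather than PyTensor. The helper name `tensordot_via_dot` is hypothetical, and the sketch assumes `axes` is always given as a pair of sequences (as in the example above); it is only meant to illustrate the idea, not how the fix should be implemented in PyTensor.

```python
import numpy as np


def tensordot_via_dot(a, b, axes):
    """Hypothetical helper: use a plain dot when tensordot reduces to one.

    Assumes `axes` is a pair of sequences, e.g. [[-1], [-1]].
    """
    a, b = np.asarray(a), np.asarray(b)
    a_axes, b_axes = axes

    # The shortcut only applies with exactly one reduction axis per input
    # and inputs that are vectors or matrices.
    one_axis_each = len(a_axes) == 1 and len(b_axes) == 1
    at_most_2d = a.ndim in (1, 2) and b.ndim in (1, 2)
    if not (one_axis_each and at_most_2d):
        # Fall back to the general (reshape-based) implementation.
        return np.tensordot(a, b, axes=axes)

    # Normalize negative axes.
    a_axis = a_axes[0] % a.ndim
    b_axis = b_axes[0] % b.ndim

    # dot reduces the last axis of `a` with the first axis of `b`,
    # so transpose matrices whose reduction axis is on the other side.
    if a.ndim == 2 and a_axis == 0:
        a = a.T
    if b.ndim == 2 and b_axis == 1:
        b = b.T

    return np.dot(a, b)


# Same shapes as the example above: matrix-vector product over the last axes.
x_test = np.random.normal(size=(5, 3))
y_test = np.random.normal(size=(3,))
np.testing.assert_allclose(
    tensordot_via_dot(x_test, y_test, axes=[[-1], [-1]]),
    np.tensordot(x_test, y_test, axes=[[-1], [-1]]),
)
```

In this NumPy sketch, transposing the inputs before the dot is enough, because the free axis of the first input already comes before the free axis of the second in the result.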
