Skip to content

Commit b1fb4ec

Browse files
jessegrabowski and zaxtax
authored and committed
Implement Einsum as OpFromGraph
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> Co-authored-by: Rob Zinkov <zaxtax@users.noreply.github.com>
1 parent 28d9d4d commit b1fb4ec

File tree

10 files changed

+605
-9
lines changed

10 files changed

+605
-9
lines changed

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414
import pytensor.link.jax.dispatch.scan
1515
import pytensor.link.jax.dispatch.sparse
1616
import pytensor.link.jax.dispatch.blockwise
17+
import pytensor.link.jax.dispatch.einsum
1718

1819
# isort: on

pytensor/link/jax/dispatch/einsum.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import jax.numpy as jnp
2+
3+
from pytensor.link.jax.dispatch import jax_funcify
4+
from pytensor.tensor.einsum import Einsum
5+
6+
7+
@jax_funcify.register(Einsum)
def jax_funcify_Einsum(op, **kwargs):
    """Build the JAX implementation of an ``Einsum`` Op.

    Parameters
    ----------
    op
        The ``Einsum`` Op being compiled; its ``subscripts`` string and
        ``optimize`` setting are read once here and closed over.
    **kwargs
        Unused; accepted for dispatch-signature compatibility.

    Returns
    -------
    callable
        A function of the Op's operands that delegates to ``jax.numpy.einsum``.
    """
    # Extract the Op's static configuration up front so the runtime
    # closure does no attribute lookups per call.
    spec = op.subscripts
    opt = op.optimize

    def einsum(*operands):
        return jnp.einsum(spec, *operands, optimize=opt)

    return einsum

pytensor/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
153153
from pytensor.tensor.functional import vectorize
154154
# isort: on
155155

156+
from pytensor.tensor.einsum import einsum
157+
156158

157159
__all__ = ["random"] # noqa: F405

pytensor/tensor/basic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1974,7 +1974,12 @@ def transpose(x, axes=None):
19741974
_x = as_tensor_variable(x)
19751975

19761976
if axes is None:
1977-
axes = list(range((_x.type.ndim - 1), -1, -1))
1977+
axes = tuple(range((_x.type.ndim - 1), -1, -1))
1978+
1979+
if tuple(axes) == tuple(range(len(axes))):
1980+
# No-op
1981+
return _x
1982+
19781983
ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
19791984

19801985
if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
@@ -3957,6 +3962,10 @@ def moveaxis(
39573962
source = normalize_axis_tuple(source, a.ndim, "source")
39583963
destination = normalize_axis_tuple(destination, a.ndim, "destination")
39593964

3965+
if source == destination:
3966+
# It's a no-op
3967+
return a
3968+
39603969
if len(source) != len(destination):
39613970
raise ValueError(
39623971
"`source` and `destination` arguments must have the same number of elements"
@@ -4271,9 +4280,7 @@ def atleast_Nd(
42714280
atleast_3d = partial(atleast_Nd, n=3)
42724281

42734282

4274-
def expand_dims(
4275-
a: np.ndarray | TensorVariable, axis: tuple[int, ...]
4276-
) -> TensorVariable:
4283+
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
42774284
"""Expand the shape of an array.
42784285
42794286
Insert a new axis that will appear at the `axis` position in the expanded
@@ -4292,7 +4299,7 @@ def expand_dims(
42924299
"""
42934300
a = as_tensor(a)
42944301

4295-
if not isinstance(axis, tuple | list):
4302+
if not isinstance(axis, Sequence):
42964303
axis = (axis,)
42974304

42984305
out_ndim = len(axis) + a.ndim

0 commit comments

Comments
 (0)