
Commit dd08faa

jessegrabowski authored and ricardoV94 committed
Implement Einsum as OpFromGraph
1 parent 14651fb commit dd08faa

File tree

8 files changed: +363 additions, −5 deletions

pytensor/compile/builders.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from typing import cast

 import pytensor.tensor as pt
-from pytensor import function
+from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
 from pytensor.compile.mode import optdb
 from pytensor.compile.sharedvalue import SharedVariable

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
 import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.einsum

 # isort: on

pytensor/link/jax/dispatch/einsum.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.einsum import Einsum
+
+
+@jax_funcify.register(Einsum)
+def jax_funcify_Einsum(op, **kwargs):
+    subscripts = op.subscripts
+    optimize = op.optimize
+
+    def einsum(*operands):
+        return jnp.einsum(subscripts, *operands, optimize=optimize)
+
+    return einsum
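For reference, the closure returned by jax_funcify_Einsum simply forwards to jnp.einsum. A minimal sketch of that behaviour, with illustrative subscripts and shapes that are not part of this commit:

import jax.numpy as jnp
import numpy as np

# Stand-ins for op.subscripts and op.optimize (hypothetical values)
subscripts, optimize = "ij,jk->ik", "optimal"

def einsum(*operands):
    # Same body as the closure built by jax_funcify_Einsum
    return jnp.einsum(subscripts, *operands, optimize=optimize)

a, b = np.ones((2, 3)), np.ones((3, 4))
np.testing.assert_allclose(einsum(a, b), a @ b)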

pytensor/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -153,5 +153,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 from pytensor.tensor.functional import vectorize
 # isort: on

+from pytensor.tensor.einsum import einsum
+

 __all__ = ["random"]  # noqa: F405

pytensor/tensor/basic.py

Lines changed: 2 additions & 4 deletions
@@ -4240,17 +4240,15 @@ def atleast_Nd(
 atleast_3d = partial(atleast_Nd, n=3)


-def expand_dims(
-    a: np.ndarray | TensorVariable, axis: tuple[int, ...]
-) -> TensorVariable:
+def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
     """Expand the shape of an array.

     Insert a new axis that will appear at the `axis` position in the expanded
     array shape.
     """
     a = as_tensor(a)

-    if not isinstance(axis, tuple | list):
+    if not isinstance(axis, Sequence):
         axis = (axis,)

     out_ndim = len(axis) + a.ndim
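With this change, expand_dims accepts any Sequence[int] for axis (a bare int is still wrapped in a tuple). A small sketch of the relaxed signature; the shapes are illustrative, and the range example assumes the rest of the function treats axis as a generic iterable:

import pytensor.tensor as pt
from pytensor.tensor.basic import expand_dims

x = pt.matrix("x")            # 2d symbolic input
y = expand_dims(x, (0, 3))    # tuple of axes, as before
z = expand_dims(x, range(2))  # any Sequence[int] now passes the isinstance check
print(y.ndim, z.ndim)         # 4 4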

pytensor/tensor/einsum.py

Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
+import collections
+from collections.abc import Sequence
+from functools import reduce
+from itertools import pairwise
+
+from numpy.core.numeric import normalize_axis_index  # type: ignore
+
+from pytensor.compile.builders import OpFromGraph
+from pytensor.tensor.basic import (
+    arange,
+    expand_dims,
+    get_vector_length,
+    stack,
+    transpose,
+    where,
+)
+from pytensor.tensor.extra_ops import broadcast_to
+from pytensor.tensor.math import and_, eq, tensordot
+from pytensor.tensor.shape import shape_padright
+from pytensor.tensor.variable import TensorVariable
+
+
+class Einsum(OpFromGraph):
+    """
+    Wrapper Op for Einsum graphs
+    """
+
+    __props__ = ("subscripts", "optimize")
+
+    def __init__(
+        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
+    ):
+        self.subscripts = subscripts
+        self.optimize = optimize
+        super().__init__(*args, **kwargs)
+
+
+def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
+    axis = normalize_axis_index(axis, get_vector_length(shape))
+    values = arange(shape[axis])
+    return broadcast_to(shape_padright(values, axis), shape)
+
+
+def _delta(shape, axes: Sequence[int]) -> TensorVariable:
+    """This utility function exists for creating Kronecker delta arrays."""
+    base_shape = stack([shape[axis] for axis in axes])
+    iotas = [_iota(base_shape, i) for i in range(len(axes))]
+    eyes = [eq(i1, i2) for i1, i2 in pairwise(iotas)]
+    result = reduce(and_, eyes)
+    return broadcast_to(expand_dims(result, tuple(axes)), shape)
+
+
+def _removechars(s, chars):
+    return s.translate(str.maketrans(dict.fromkeys(chars)))
+
+
+def einsum(subscripts: str, *operands):
+    """
+    Multiplication and summation of tensors using the Einstein summation convention.
+
+    # TODO: Write docs
+
+    Parameters
+    ----------
+    subscripts: str
+
+    operands: sequence of TensorVariable
+        Tensors to be multiplied and summed.
+
+    Returns
+    -------
+    TensorVariable
+        The result of the einsum operation.
+    """
+    # TODO: Is this doing something clever about unknown shapes?
+    # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
+    # using einsum_call=True here is an internal api for opt_einsum... sorry
+
+    # TODO: Handle None static shapes
+    # TODO: Do we need this as dependency?
+    from opt_einsum import contract_path
+
+    shapes = [operand.type.shape for operand in operands]
+
+    # TODO: Do fast path at creation time, and optimize only in fast_run
+    _, contraction_list = contract_path(
+        subscripts,
+        *shapes,
+        einsum_call=True,
+        use_blas=True,
+        optimize="optimal",
+        shapes=True,
+    )
+
+    def sum_uniques(
+        operand: TensorVariable, names: str, uniques: list[str]
+    ) -> tuple[TensorVariable, str]:
+        if uniques:
+            axes = [names.index(name) for name in uniques]
+            operand = operand.sum(axes)
+            names = _removechars(names, uniques)
+        return operand, names
+
+    def sum_repeats(
+        operand: TensorVariable,
+        names: str,
+        counts: collections.Counter,
+        keep_names: str,
+    ) -> tuple[TensorVariable, str]:
+        for name, count in counts.items():
+            if count > 1:
+                axes = [i for i, n in enumerate(names) if n == name]
+                eye = _delta(operand.shape, axes)
+                operand = where(eye, operand, operand.zeros_like())
+                if name not in keep_names:
+                    operand = operand.sum(axes)
+                    names = names.replace(name, "")
+                else:
+                    operand = operand.sum(axes[:-1])
+                    names = names.replace(name, "", count - 1)
+        return operand, names
+
+    # def filter_singleton_dims(operand, names, other_shape, other_names):
+    #     eq = core.definitely_equal
+    #     keep = [
+    #         not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
+    #         for i, j in enumerate(map(other_names.find, names))
+    #     ]
+    #     sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
+    #     return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)
+
+    einsum_operands = list(operands)  # So we can pop
+    for operand_indices, contracted_names, einstr, _, _ in contraction_list:
+        contracted_names = sorted(contracted_names)
+        assert len(contracted_names) == len(
+            set(contracted_names)
+        ), "The set was needed!"
+
+        input_str, result_names = einstr.split("->")
+        input_names = input_str.split(",")
+
+        # switch on the number of operands to be processed in this loop iteration.
+        # every case here sets 'operand' and 'names'.
+        if len(operand_indices) == 1:
+            operand = einsum_operands.pop(operand_indices[0])
+            (names,) = input_names
+            counts = collections.Counter(names)
+
+            # sum out unique contracted indices with a single reduce-sum
+            uniques = [name for name in contracted_names if counts[name] == 1]
+            operand, names = sum_uniques(operand, names, uniques)
+
+            # for every repeated index, do a contraction against an identity matrix
+            operand, names = sum_repeats(operand, names, counts, result_names)
+
+        elif len(operand_indices) == 2:
+            lhs, rhs = map(einsum_operands.pop, operand_indices)
+            lhs_names, rhs_names = input_names
+
+            # handle cases where one side of a contracting or batch dimension is 1
+            # but its counterpart is not.
+            # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
+            #                                        rhs_names)
+            # rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
+            #                                        lhs_names)
+
+            lhs_counts = collections.Counter(lhs_names)
+            rhs_counts = collections.Counter(rhs_names)
+
+            # sum out unique contracted indices in lhs and rhs
+            lhs_uniques = [
+                name
+                for name in contracted_names
+                if lhs_counts[name] == 1 and rhs_counts[name] == 0
+            ]
+            lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)
+
+            rhs_uniques = [
+                name
+                for name in contracted_names
+                if rhs_counts[name] == 1 and lhs_counts[name] == 0
+            ]
+            rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)
+
+            # for every repeated index, contract against an identity matrix
+            lhs, lhs_names = sum_repeats(
+                lhs, lhs_names, lhs_counts, result_names + rhs_names
+            )
+            rhs, rhs_names = sum_repeats(
+                rhs, rhs_names, rhs_counts, result_names + lhs_names
+            )
+
+            lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
+            contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
+            lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
+            batch_names = [x for x in result_names if x in lhs_and_rhs_names]
+
+            if batch_names:
+                lhs_batch, rhs_batch = tuple(
+                    zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names])
+                )
+                if lhs_batch or rhs_batch:
+                    raise NotImplementedError("Batch dimensions are not yet supported")
+            else:
+                lhs_batch = rhs_batch = ()
+
+            # contract using dot_general
+            batch_names_str = "".join(batch_names)
+            if contracted_names:
+                lhs_cont, rhs_cont = tuple(
+                    zip(
+                        *[
+                            (lhs_names.index(n), rhs_names.index(n))
+                            for n in contracted_names
+                        ]
+                    )
+                )
+            else:
+                lhs_cont = rhs_cont = ()
+            deleted_names = batch_names_str + "".join(contracted_names)
+            remaining_lhs_names = _removechars(lhs_names, deleted_names)
+            remaining_rhs_names = _removechars(rhs_names, deleted_names)
+            # Try both orders of lhs and rhs, in the hope that one of them means we
+            # don't need an explicit transpose. opt_einsum likes to contract from
+            # right to left, so we expect (rhs,lhs) to have the best chance of not
+            # needing a transpose.
+            names = batch_names_str + remaining_rhs_names + remaining_lhs_names
+            if names == result_names:
+                operand = tensordot(rhs, lhs, (rhs_cont, lhs_cont))
+            else:
+                names = batch_names_str + remaining_lhs_names + remaining_rhs_names
+                operand = tensordot(lhs, rhs, axes=(lhs_cont, rhs_cont))
+
+        # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
+        assert len(names) == len(result_names) == len(set(names))
+        assert set(names) == set(result_names)
+        if names != result_names:
+            perm = tuple(names.index(name) for name in result_names)
+            operand = transpose(operand, perm)
+        einsum_operands.append(operand)  # used in next iteration
+
+    [einsum_result] = einsum_operands
+
+    return Einsum(
+        subscripts=subscripts,
+        inputs=list(operands),
+        outputs=[einsum_result],
+    )(*operands)
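Since einsum is re-exported from pytensor.tensor (see the __init__.py change above), the new Op can also be exercised on the default backend. A minimal usage sketch with illustrative shapes, checked against numpy:

import numpy as np
import pytensor.tensor as pt

x = pt.tensor("x", shape=(3, 5))
y = pt.tensor("y", shape=(5, 2))
out = pt.einsum("ij,jk->ik", x, y)  # builds an Einsum OpFromGraph around tensordot

x_val = np.random.rand(3, 5)
y_val = np.random.rand(5, 2)
np.testing.assert_allclose(
    out.eval({x: x_val, y: y_val}),
    np.einsum("ij,jk->ik", x_val, y_val),
)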

tests/link/jax/test_einsum.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import numpy as np
+import pytest
+
+import pytensor
+import pytensor.tensor as pt
+
+
+jax = pytest.importorskip("jax")
+
+
+def test_jax_einsum():
+    subscripts = "ij, jk, kl -> il"
+    x = np.random.rand(3, 5)
+    y = np.random.rand(5, 2)
+    z = np.random.rand(2, 4)
+
+    shapes = ((3, 5), (5, 2), (2, 4))
+    x_pt, y_pt, z_pt = (
+        pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
+    )
+    out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
+    f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
+
+    np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))
+
+
+@pytest.mark.xfail(raises=NotImplementedError)
+def test_ellipsis_einsum():
+    subscripts = "...i,...i->..."
+    x = np.random.rand(2, 5)
+    y = np.random.rand(2, 5)
+
+    x_pt = pt.tensor("x", shape=x.shape)
+    y_pt = pt.tensor("y", shape=y.shape)
+    out = pt.einsum(subscripts, x_pt, y_pt)
+    f = pytensor.function([x_pt, y_pt], out, mode="JAX")
+
+    np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
