Commit 23a5e00

Add docstrings
1 parent 5e83d26 commit 23a5e00

File tree

1 file changed: +157 -11 lines changed

pytensor/tensor/einsum.py

Lines changed: 157 additions & 11 deletions
@@ -148,9 +148,42 @@ def _general_dot(
 def contraction_list_from_path(
     subscripts: str, operands: Sequence[TensorLike], path: PATH
 ):
-    """TODO Docstrings
+    """
+    Generate a list of contraction steps based on the provided einsum path.
+
+    Code adapted from einsum_opt: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369
+
+    When all shapes are known, the linked einsum_opt implementation is preferred. This implementation is used when
+    some or all shapes are not known. As a result, contraction will (always?) be done left-to-right, pushing intermediate
+    results to the end of the stack.
 
-    Code adapted from einsum_opt
+    Parameters
+    ----------
+    subscripts: str
+        Einsum signature string describing the computation to be performed.
+
+    operands: Sequence[TensorLike]
+        Tensors described by the subscripts.
+
+    path: tuple[tuple[int] | tuple[int, int]]
+        A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in the order
+        they should be contracted.
+
+    Returns
+    -------
+    contraction_list: list
+        A list of tuples, where each tuple describes a contraction step. Each tuple contains the following elements:
+            - contraction_inds: tuple[int]
+                The indices of the operands to be contracted
+            - idx_removed: str
+                The indices that are contracted away (those removed from the einsum string at this step)
+            - einsum_str: str
+                The einsum string for the contraction step
+            - remaining: None
+                The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
+            - do_blas: None
+                Whether to use BLAS to perform this step. Included to match the output of opt_einsum.contract_path,
+                but not used.
     """
     fake_operands = [
         np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
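Not part of this commit: a minimal sketch of how `contraction_list_from_path` and the returned contraction steps fit together, assuming the function can be imported from `pytensor.tensor.einsum`. The operand names, shapes, and signature below are made up for illustration.

.. code-block:: python

    import pytensor.tensor as pt
    from pytensor.tensor.einsum import contraction_list_from_path

    # Three matrices with fully unknown shapes
    A = pt.tensor("A", shape=(None, None))
    B = pt.tensor("B", shape=(None, None))
    C = pt.tensor("C", shape=(None, None))

    # Default left-to-right path: contract operands (1, 0), push the intermediate
    # result to the end of the stack, then contract (1, 0) again.
    path = [(1, 0), (1, 0)]
    contraction_list = contraction_list_from_path("ij,jk,kl->il", [A, B, C], path)

    for contraction_inds, idx_removed, einsum_str, remaining, do_blas in contraction_list:
        # Each step records which operands were contracted, which indices were
        # removed, and the einsum string for that single pairwise contraction.
        print(contraction_inds, idx_removed, einsum_str)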
@@ -199,9 +232,13 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
 
     Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283
 
+    Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation convention. Using
+    this notation, many common linear algebraic operations can be succinctly described on higher order tensors.
+
     Parameters
     ----------
     subscripts: str
+        Einsum signature string describing the computation to be performed.
 
     operands: sequence of TensorVariable
         Tensors to be multiplied and summed.
@@ -210,7 +247,110 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     -------
     TensorVariable
         The result of the einsum operation.
+
+    See Also
+    --------
+    pytensor.tensor.tensordot: Generalized dot product between two tensors
+    pytensor.tensor.dot: Matrix multiplication between two tensors
+    numpy.einsum: The numpy implementation of einsum
+
+    Examples
+    --------
+    Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts"), and a sequence of
+    tensors to be operated on. The string must follow these rules:
+
+    1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->".
+    2. The input side of the string is a comma-separated list of indices. For each comma-separated index string, there
+       must be a corresponding tensor in the input sequence.
+    3. For each index string, the number of dimensions in the corresponding tensor must match the number of characters
+       in the index string.
+    4. Indices are arbitrary strings of characters. If an index appears multiple times on the input side, it must have
+       the same shape in each input.
+    5. The indices on the output side must be a subset of the indices on the input side -- you cannot introduce new
+       indices in the output.
+    6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of "batch"
+       dimensions that are not implicated in the operation.
+
+    Finally, two rules about these indices govern how computation is carried out:
+
+    1. Repeated indices on the input side indicate how the tensor should be "aligned" for multiplication.
+    2. Indices that appear on the input side but not the output side are summed over.
+
+    The operation of these rules is best understood via examples:
+
+    Example 1: Matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.matrix("A")
+        B = pt.matrix("B")
+        C = pt.einsum("ij, jk -> ik", A, B)
+
+    This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input side of the
+    signature, and does not appear on the output side. This indicates that the ``j`` dimension of the first tensor should be
+    multiplied with the ``j`` dimension of the second tensor, and the resulting tensor's ``j`` dimension should be summed
+    away.
+
+    Example 2: Batched matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(None, 4, 5))
+        B = pt.tensor("B", shape=(None, 5, 6))
+        C = pt.einsum("bij, bjk -> bik", A, B)
+
+    This computation is also equivalent to :code:`C = A @ B` because of PyTensor's built-in broadcasting rules, but
+    the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are repeated on the
+    input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating that the batch dimension
+    should **not** be summed away. As a result, multiplication will be performed over the ``b, j`` dimensions, and then
+    the ``j`` dimension will be summed over. The resulting tensor will have shape ``(None, 4, 6)``.
+
+    Example 3: Batched matrix multiplication with ellipses
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(4, None, None, None, 5))
+        B = pt.tensor("B", shape=(5, None, None, None, 6))
+        C = pt.einsum("i...j, j...k -> ...ik", A, B)
+
+    This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid writing out all
+    of the batch dimensions (which we do not care about), we can use ellipses to elide over these dimensions. Notice
+    also that we are not required to "sort" the input dimensions in any way. In this example, we are doing a dot
+    between the last dimension of ``A`` and the first dimension of ``B``, which is perfectly valid.
+
+    Example 4: Outer product
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(3,))
+        y = pt.tensor("y", shape=(4,))
+        z = pt.einsum("i, j -> ij", x, y)
+
+    This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input side,
+    and the output side has two indices. Since there are no indices to align on, the einsum operation will simply
+    multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``.
+
+    Example 5: Convolution
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(None, None, None, None, None, None))
+        w = pt.tensor("w", shape=(None, None, None, None))
+        y = pt.einsum("bchwkt,fckt->bfhw", x, w)
+
+    Given a batch of images ``x`` with dimensions ``(batch, channels, height, width, kernel_height, kernel_width)``
+    and a filter ``w`` with dimensions ``(num_filters, channels, kernel_height, kernel_width)``, this einsum operation
+    computes the convolution of ``x`` with ``w``. Multiplication is aligned on the repeated ``c``, ``k``, and ``t``
+    indices (the channel and the two kernel dimensions), which are then summed over. The resulting tensor has shape
+    ``(batch, num_filters, height, width)``, reflecting the fact that information from each channel has been mixed
+    together.
     """
+
     # TODO: Is this doing something clever about unknown shapes?
     # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
     # using einsum_call=True here is an internal api for opt_einsum... sorry
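As a sanity check of the examples above (again, not part of the diff), the docstring's Example 1 can be compiled with `pytensor.function` and compared against `numpy.einsum`; the shapes and tolerance below are arbitrary.

.. code-block:: python

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    A = pt.matrix("A")
    B = pt.matrix("B")
    C = pt.einsum("ij, jk -> ik", A, B)

    f = pytensor.function([A, B], C)

    rng = np.random.default_rng(0)
    a = rng.normal(size=(2, 3)).astype(pytensor.config.floatX)
    b = rng.normal(size=(3, 4)).astype(pytensor.config.floatX)

    # The compiled einsum should agree with numpy's reference implementation
    np.testing.assert_allclose(f(a, b), np.einsum("ij,jk->ik", a, b), rtol=1e-5)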
@@ -223,21 +363,24 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     shapes = [operand.type.shape for operand in operands]
 
     if None in itertools.chain.from_iterable(shapes):
-        # We mark optimize = False, even in cases where there is no ordering optimization to be done
-        # because the inner graph may have to accommodate dynamic shapes.
-        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        # Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
+        # the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
+        # pushing intermediate results to the end of the stack.
+        # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will
+        # match more often
+
+        # If shapes become known later we will likely want to rebuild the Op (unless we inline it)
         if len(operands) == 1:
             path = [(0,)]
         else:
-            # Create default path of repeating (1,0) that executes left to right cyclically
-            # with intermediate outputs being pushed to the end of the stack
-            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
             path = [(1, 0) for i in range(len(operands) - 1)]
         contraction_list = contraction_list_from_path(subscripts, operands, path)
-        optimize = (
-            len(operands) <= 2
-        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+
+        # If there are only 1 or 2 operands, there is no optimization to be done?
+        optimize = len(operands) <= 2
     else:
+        # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
+        # contraction order.
        _, contraction_list = contract_path(
            subscripts,
            *shapes,
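For the known-shape branch, a short sketch of opt_einsum's documented `contract_path` interface (the `einsum_call=True` form used in the diff is internal). The shapes below are made up purely to show that the chosen contraction order depends on them.

.. code-block:: python

    import numpy as np
    from opt_einsum import contract_path

    # Arbitrary shapes chosen so that the contraction order matters
    a = np.empty((8, 4))
    b = np.empty((4, 128))
    c = np.empty((128, 8))

    # path is a list of pairwise contractions; info summarizes the estimated cost
    path, info = contract_path("ij,jk,kl->il", a, b, c)
    print(path)   # e.g. [(1, 2), (0, 1)] for these shapes
    print(info)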
@@ -252,6 +395,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
     ) -> tuple[TensorVariable, str]:
+        """Reduce unique indices (those that appear only once) in a given contraction step via summing."""
         if uniques:
             axes = [names.index(name) for name in uniques]
             operand = operand.sum(axes)
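A small numpy illustration (not part of the commit) of what `sum_uniques` refers to: an index that appears in only one operand and not in the output can be summed away before any contraction takes place.

.. code-block:: python

    import numpy as np

    x = np.arange(12.0).reshape(3, 4)

    # "ij -> j": the unique index i is reduced by a plain sum over that axis
    assert np.allclose(np.einsum("ij->j", x), x.sum(axis=0))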
@@ -264,6 +408,8 @@ def sum_repeats(
         counts: collections.Counter,
         keep_names: str,
     ) -> tuple[TensorVariable, str]:
+        """Reduce repeated indices in a given contraction step via summation against an identity matrix."""
+
         for name, count in counts.items():
             if count > 1:
                 axes = [i for i, n in enumerate(names) if n == name]
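And a numpy sketch of the idea named in the `sum_repeats` docstring: a repeated index can be reduced by multiplying against an identity matrix and summing. This only illustrates the trick, not the exact code path in the diff.

.. code-block:: python

    import numpy as np

    x = np.arange(9.0).reshape(3, 3)

    # "ii ->": the repeated index i is aligned via the identity and then summed away (the trace)
    assert np.isclose((x * np.eye(3)).sum(), np.einsum("ii->", x))

    # "ii -> i": keeping one copy of the repeated index recovers the diagonal
    assert np.allclose((x * np.eye(3)).sum(axis=1), np.einsum("ii->i", x))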
