@@ -148,9 +148,42 @@ def _general_dot(
def contraction_list_from_path(
    subscripts: str, operands: Sequence[TensorLike], path: PATH
):
-    """TODO Docstrings
+    """
+    Generate a list of contraction steps based on the provided einsum path.
+
+    Code adapted from opt_einsum: https://github.com/dgasmith/opt_einsum/blob/94c62a05d5ebcedd30f59c90b9926de967ed10b5/opt_einsum/contract.py#L369
+
+    When all shapes are known, the linked opt_einsum implementation is preferred. This implementation is used
+    when some or all shapes are not known. As a result, contraction will (always?) be done left-to-right,
+    pushing intermediate results to the end of the stack.

-    Code adapted from einsum_opt
+    Parameters
+    ----------
+    subscripts: str
+        Einsum signature string describing the computation to be performed.
+
+    operands: Sequence[TensorLike]
+        Tensors described by the subscripts.
+
+    path: tuple[tuple[int] | tuple[int, int]]
+        A list of tuples, where each tuple describes the indices of the operands to be contracted, sorted in
+        the order they should be contracted.
+
+    Returns
+    -------
+    contraction_list: list
+        A list of tuples, where each tuple describes a contraction step. Each tuple contains the following
+        elements:
+
+        - contraction_inds: tuple[int]
+            The indices of the operands to be contracted
+        - idx_removed: str
+            The contracted indices (those removed from the einsum string at this step)
+        - einsum_str: str
+            The einsum string for the contraction step
+        - remaining: None
+            The remaining indices. Included to match the output of opt_einsum.contract_path, but not used.
+        - do_blas: None
+            Whether to use BLAS to perform this step. Included to match the output of opt_einsum.contract_path,
+            but not used.
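+
+    Examples
+    --------
+    A rough usage sketch (illustrative only; the exact per-step einsum strings and operand orderings are
+    implementation details):
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        A, B, C = pt.matrix("A"), pt.matrix("B"), pt.matrix("C")
+        # Contract left-to-right: A and B first, then the intermediate result with C
+        path = [(1, 0), (1, 0)]
+        contraction_list = contraction_list_from_path("ij,jk,kl->il", [A, B, C], path)
+        for contraction_inds, idx_removed, einsum_str, remaining, do_blas in contraction_list:
+            ...  # each step describes a one- or two-operand einsum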
    """
    fake_operands = [
        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
@@ -199,9 +232,13 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:

    Code adapted from JAX: https://github.com/google/jax/blob/534d32a24d7e1efdef206188bb11ae48e9097092/jax/_src/numpy/lax_numpy.py#L5283

+    Einsum allows the user to specify a wide range of operations on tensors using the Einstein summation
+    convention. Using this notation, many common linear algebraic operations can be succinctly described on
+    higher-order tensors.
+
    Parameters
    ----------
    subscripts: str
+        Einsum signature string describing the computation to be performed.

    operands: sequence of TensorVariable
        Tensors to be multiplied and summed.
@@ -210,7 +247,110 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    -------
    TensorVariable
        The result of the einsum operation.
+
+    See Also
+    --------
+    pytensor.tensor.tensordot: Generalized dot product between two tensors
+    pytensor.tensor.dot: Matrix multiplication between two tensors
+    numpy.einsum: The numpy implementation of einsum
+
+    Examples
+    --------
+    Inputs to `pt.einsum` are a string describing the operation to be performed (the "subscripts") and a
+    sequence of tensors to be operated on. The string must obey the following rules:
+
+    1. The string gives inputs and (optionally) outputs. Inputs and outputs are separated by "->".
+    2. The input side of the string is a comma-separated list of indices. For each comma-separated index
+       string, there must be a corresponding tensor in the input sequence.
+    3. For each index string, the number of dimensions in the corresponding tensor must match the number of
+       characters in the index string.
+    4. Indices are arbitrary strings of characters. If an index appears multiple times in the input side, it
+       must have the same shape in each input.
+    5. The indices on the output side must be a subset of the indices on the input side -- you cannot
+       introduce new indices in the output.
+    6. Ellipses ("...") can be used to elide multiple indices. This is useful when you have a large number of
+       "batch" dimensions that are not implicated in the operation.
+
+    Finally, two rules about these indices govern how computation is carried out:
+
+    1. Repeated indices on the input side indicate how the tensor should be "aligned" for multiplication.
+    2. Indices that appear on the input side but not the output side are summed over.
+
+    The operation of these rules is best understood via examples:
+
+    Example 1: Matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.matrix("A")
+        B = pt.matrix("B")
+        C = pt.einsum("ij, jk -> ik", A, B)
+
+    This computation is equivalent to :code:`C = A @ B`. Notice that the ``j`` index is repeated on the input
+    side of the signature and does not appear on the output side. This indicates that the ``j`` dimension of
+    the first tensor should be multiplied with the ``j`` dimension of the second tensor, and the resulting
+    tensor's ``j`` dimension should be summed away.
+
+    Example 2: Batched matrix multiplication
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(None, 4, 5))
+        B = pt.tensor("B", shape=(None, 5, 6))
+        C = pt.einsum("bij, bjk -> bik", A, B)
+
+    This computation is also equivalent to :code:`C = A @ B` because of PyTensor's built-in broadcasting rules,
+    but the einsum signature is more explicit about the batch dimensions. The ``b`` and ``j`` indices are
+    repeated on the input side. Unlike ``j``, the ``b`` index is also present on the output side, indicating
+    that the batch dimension should **not** be summed away. As a result, multiplication will be performed over
+    the ``b, j`` dimensions, and then the ``j`` dimension will be summed over. The resulting tensor will have
+    shape ``(None, 4, 6)``.
+
+    Example 3: Batched matrix multiplication with ellipses
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        A = pt.tensor("A", shape=(4, None, None, None, 5))
+        B = pt.tensor("B", shape=(5, None, None, None, 6))
+        C = pt.einsum("i...j, j...k -> ...ik", A, B)
+
+    This case is the same as above, but inputs ``A`` and ``B`` have multiple batch dimensions. To avoid
+    writing out all of the batch dimensions (which we do not care about), we can use ellipses to elide over
+    these dimensions. Notice also that we are not required to "sort" the input dimensions in any way. In this
+    example, we are doing a dot between the last dimension of ``A`` and the first dimension of ``B``, which is
+    perfectly valid.
+
+    Example 4: Outer product
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(3,))
+        y = pt.tensor("y", shape=(4,))
+        z = pt.einsum("i, j -> ij", x, y)
+
+    This computation is equivalent to :code:`pt.outer(x, y)`. Notice that no indices are repeated on the input
+    side, and the output side has two indices. Since there are no indices to align on, the einsum operation
+    will simply multiply the two tensors elementwise, broadcasting dimensions ``i`` and ``j``.
+
+    Example 5: Convolution
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+        x = pt.tensor("x", shape=(None, None, None, None, None, None))
+        w = pt.tensor("w", shape=(None, None, None, None))
+        y = pt.einsum("bchwkt,fckt->bfhw", x, w)
+
+    Given a batch of images ``x``, arranged as sliding patches with dimensions ``(batch, channel, height,
+    width, kernel_height, kernel_width)``, and a filter bank ``w`` with dimensions ``(num_filters, channel,
+    kernel_height, kernel_width)``, this einsum operation computes the convolution of ``x`` with ``w``.
+    Multiplication is aligned on the channel and the two kernel dimensions, which are then summed away. The
+    resulting tensor has shape ``(batch, num_filters, height, width)``, reflecting the fact that information
+    from each channel has been mixed together.
    """
+
    # TODO: Is this doing something clever about unknown shapes?
    # contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
    # using einsum_call=True here is an internal api for opt_einsum... sorry
@@ -223,21 +363,24 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    shapes = [operand.type.shape for operand in operands]

    if None in itertools.chain.from_iterable(shapes):
-        # We mark optimize = False, even in cases where there is no ordering optimization to be done
-        # because the inner graph may have to accommodate dynamic shapes.
-        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
+        # Case 1: At least one of the operands has an unknown shape. In this case, we can't use opt_einsum to optimize
+        # the contraction order, so we just use a default path of (1,0) contractions. This will work left-to-right,
+        # pushing intermediate results to the end of the stack.
+        # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will
+        # match more often.
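+        # Illustrative example: with four operands this default path is [(1, 0), (1, 0), (1, 0)]; each step
+        # contracts the first two operands on the stack and appends the intermediate result to the end.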
+
+        # If shapes become known later we will likely want to rebuild the Op (unless we inline it)
        if len(operands) == 1:
            path = [(0,)]
        else:
-            # Create default path of repeating (1,0) that executes left to right cyclically
-            # with intermediate outputs being pushed to the end of the stack
-            # We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
            path = [(1, 0) for i in range(len(operands) - 1)]
        contraction_list = contraction_list_from_path(subscripts, operands, path)
-        optimize = (
-            len(operands) <= 2
-        )  # If there are only 1 or 2 operands, there is no optimization to be done?
+
+        # If there are only 1 or 2 operands, there is no optimization to be done?
+        optimize = len(operands) <= 2
    else:
+        # Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
+        # contraction order.
        _, contraction_list = contract_path(
            subscripts,
            *shapes,
@@ -252,6 +395,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    def sum_uniques(
        operand: TensorVariable, names: str, uniques: list[str]
    ) -> tuple[TensorVariable, str]:
+        """Reduce unique indices (those that appear only once) in a given contraction step via summing."""
        if uniques:
            axes = [names.index(name) for name in uniques]
            operand = operand.sum(axes)
@@ -264,6 +408,8 @@ def sum_repeats(
        counts: collections.Counter,
        keep_names: str,
    ) -> tuple[TensorVariable, str]:
+        """Reduce repeated indices in a given contraction step via summation against an identity matrix."""
+
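+        # Illustrative example: for "ii->" (a trace) the repeated index is masked with an identity matrix and
+        # both axes are summed; for "ii->i" only one of the repeated axes is summed, leaving the diagonal.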
        for name, count in counts.items():
            if count > 1:
                axes = [i for i, n in enumerate(names) if n == name]