import collections
from collections.abc import Sequence
- from functools import reduce
+ from functools import partial, reduce
from itertools import pairwise
+ from typing import cast

- from numpy.core.numeric import normalize_axis_index  # type: ignore
+ import numpy as np
+ from numpy.core.numeric import (  # type: ignore
+     normalize_axis_index,
+     normalize_axis_tuple,
+ )

from pytensor.compile.builders import OpFromGraph
+ from pytensor.tensor import vectorize
from pytensor.tensor.basic import (
    arange,
    get_vector_length,
@@ -54,6 +60,60 @@ def _removechars(s, chars):
    return s.translate(str.maketrans(dict.fromkeys(chars)))


+ def _batched_tensordot(
+     vars: tuple[TensorVariable, TensorVariable],
+     axes: Sequence[Sequence[int]],  # Should be length 2
+     batch_axes: Sequence[Sequence[int]],  # Should be length 2
+ ) -> TensorVariable:
+     # Shortcut for the non-batched case
+     if not batch_axes[0] and not batch_axes[1]:
+         return tensordot(*vars, axes=axes)
+
+     # Normalize axes; thankfully the numpy helper does not sort the axes!
+     axes = [
+         normalize_axis_tuple(var_axes, var.ndim) for var, var_axes in zip(vars, axes)
+     ]
+     batch_axes = [
+         normalize_axis_tuple(var_axes, var.ndim)
+         for var, var_axes in zip(vars, batch_axes)
+     ]
+     n_batch_axes = [len(var_batch_axes) for var_batch_axes in batch_axes]
+     if any(
+         var_batch_axes != tuple(range(var_n_batch_axes))
+         for var_batch_axes, var_n_batch_axes in zip(batch_axes, n_batch_axes)
+     ):
+         raise NotImplementedError("Batch dimensions must be on the left")
+
+     lhs, rhs = vars
+     lhs_axes, rhs_axes = axes
+     lhs_n_batch_axes, rhs_n_batch_axes = n_batch_axes
+
+     # Create the signature for tensordot
+     lhs_signature = [f"l{i}" for i in range(lhs.type.ndim)]
+     rhs_signature = [f"r{i}" for i in range(rhs.type.ndim)]
+     # Aligned axes get the same dimension name
+     for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)):
+         lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}"
+     # Trim away the batch ndims
+     lhs_signature = lhs_signature[lhs_n_batch_axes:]
+     rhs_signature = rhs_signature[rhs_n_batch_axes:]
+     out_signature = [
+         lhs_dim for lhs_dim in lhs_signature if not lhs_dim.startswith("a")
+     ] + [rhs_dim for rhs_dim in rhs_signature if not rhs_dim.startswith("a")]
+     signature = f"({','.join(lhs_signature)}),({','.join(rhs_signature)})->({','.join(out_signature)})"
+     print(signature)
+     # Adjust axes for the core (non-batched) case
+     core_lhs_axes = tuple(np.array(lhs_axes) - lhs_n_batch_axes)
+     core_rhs_axes = tuple(np.array(rhs_axes) - rhs_n_batch_axes)
+
+     # TODO: Make sure this looks reasonable after optimizations
+     # Right now we have some Blockwise(Reshape) that will slow things down!
+     out = vectorize(
+         partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature
+     )(lhs, rhs)
+     return cast(TensorVariable, out)
+
+
def einsum(subscripts: str, *operands):
    """
    Multiplication and summation of tensors using the Einstein summation convention.
@@ -199,8 +259,6 @@ def sum_repeats(
                lhs_batch, rhs_batch = tuple(
                    zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names])
                )
-                 if lhs_batch or rhs_batch:
-                     raise NotImplementedError("Batch dimensions are not yet supported")
            else:
                lhs_batch = rhs_batch = ()

@@ -226,10 +284,16 @@ def sum_repeats(
            # needing a transpose.
            names = batch_names_str + remaining_rhs_names + remaining_lhs_names
            if names == result_names:
-                 operand = tensordot(rhs, lhs, (rhs_cont, lhs_cont))
+                 operand = _batched_tensordot(
+                     (rhs, lhs), (rhs_cont, lhs_cont), (rhs_batch, lhs_batch)
+                 )
            else:
                names = batch_names_str + remaining_lhs_names + remaining_rhs_names
-                 operand = tensordot(lhs, rhs, axes=(lhs_cont, rhs_cont))
+                 operand = _batched_tensordot(
+                     (lhs, rhs),
+                     axes=(lhs_cont, rhs_cont),
+                     batch_axes=(lhs_batch, rhs_batch),
+                 )

            # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
            assert len(names) == len(result_names) == len(set(names))
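
A minimal NumPy sketch of the idea behind _batched_tensordot: a batched contraction is expressed as a core tensordot vectorized over the leading batch axes through a gufunc-style signature, which is what the vectorize(partial(tensordot, ...), signature=...) call in the diff does. The shapes, axis choices, and dimension names below are illustrative assumptions, not taken from the diff.

import numpy as np

rng = np.random.default_rng(0)
lhs = rng.normal(size=(4, 2, 3))  # batch axis 0, core shape (l1, a0)
rhs = rng.normal(size=(4, 3, 5))  # batch axis 0, core shape (a0, r2)

# Contract lhs core axis 1 with rhs core axis 0, looping over the shared batch axis
batched_dot = np.vectorize(
    lambda x, y: np.tensordot(x, y, axes=[(1,), (0,)]),
    signature="(l1,a0),(a0,r2)->(l1,r2)",
)

out = batched_dot(lhs, rhs)  # shape (4, 2, 5)
np.testing.assert_allclose(out, np.einsum("bij,bjk->bik", lhs, rhs))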