Skip to content

Commit 086c717

Browse files
committed
Speedup python blockwise
1 parent 9a49994 commit 086c717

File tree

4 files changed

+198
-88
lines changed

4 files changed

+198
-88
lines changed

pytensor/graph/op.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def make_py_thunk(
502502
self,
503503
node: Apply,
504504
storage_map: StorageMapType,
505-
compute_map: ComputeMapType,
505+
compute_map: ComputeMapType | None,
506506
no_recycling: list[Variable],
507507
debug: bool = False,
508508
) -> ThunkType:
@@ -513,25 +513,38 @@ def make_py_thunk(
513513
"""
514514
node_input_storage = [storage_map[r] for r in node.inputs]
515515
node_output_storage = [storage_map[r] for r in node.outputs]
516-
node_compute_map = [compute_map[r] for r in node.outputs]
517516

518517
if debug and hasattr(self, "debug_perform"):
519518
p = node.op.debug_perform
520519
else:
521520
p = node.op.perform
522521

523-
@is_thunk_type
524-
def rval(
525-
p=p,
526-
i=node_input_storage,
527-
o=node_output_storage,
528-
n=node,
529-
cm=node_compute_map,
530-
):
531-
r = p(n, [x[0] for x in i], o)
532-
for entry in cm:
533-
entry[0] = True
534-
return r
522+
if compute_map is None:
523+
524+
@is_thunk_type
525+
def rval(
526+
p=p,
527+
i=node_input_storage,
528+
o=node_output_storage,
529+
n=node,
530+
):
531+
return p(n, [x[0] for x in i], o)
532+
533+
else:
534+
node_compute_map = [compute_map[r] for r in node.outputs]
535+
536+
@is_thunk_type
537+
def rval(
538+
p=p,
539+
i=node_input_storage,
540+
o=node_output_storage,
541+
n=node,
542+
cm=node_compute_map,
543+
):
544+
r = p(n, [x[0] for x in i], o)
545+
for entry in cm:
546+
entry[0] = True
547+
return r
535548

536549
rval.inputs = node_input_storage
537550
rval.outputs = node_output_storage

pytensor/link/c/op.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def make_c_thunk(
3939
self,
4040
node: Apply,
4141
storage_map: StorageMapType,
42-
compute_map: ComputeMapType,
42+
compute_map: ComputeMapType | None,
4343
no_recycling: Collection[Variable],
4444
) -> CThunkWrapperType:
4545
"""Create a thunk for a C implementation.
@@ -86,11 +86,17 @@ def is_f16(t):
8686
)
8787
thunk, node_input_filters, node_output_filters = outputs
8888

89-
@is_cthunk_wrapper_type
90-
def rval():
91-
thunk()
92-
for o in node.outputs:
93-
compute_map[o][0] = True
89+
if compute_map is None:
90+
rval = is_cthunk_wrapper_type(thunk)
91+
92+
else:
93+
cm_entries = [compute_map[o] for o in node.outputs]
94+
95+
@is_cthunk_wrapper_type
96+
def rval(thunk=thunk, cm_entries=cm_entries):
97+
thunk()
98+
for entry in cm_entries:
99+
entry[0] = True
94100

95101
rval.thunk = thunk
96102
rval.cthunk = thunk.cthunk

pytensor/tensor/blockwise.py

Lines changed: 137 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from collections.abc import Sequence
1+
from collections.abc import Callable, Sequence
22
from typing import Any, cast
33

44
import numpy as np
5+
from numpy import broadcast_shapes, empty
56

67
from pytensor import config
78
from pytensor.compile.builders import OpFromGraph
@@ -22,12 +23,110 @@
2223
from pytensor.tensor.utils import (
2324
_parse_gufunc_signature,
2425
broadcast_static_dim_lengths,
26+
faster_broadcast_to,
27+
faster_ndindex,
2528
import_func_from_string,
2629
safe_signature,
2730
)
2831
from pytensor.tensor.variable import TensorVariable
2932

3033

34+
def _vectorize_node_perform(
35+
core_node, batch_bcast_patterns, batch_ndim: int, impl=None
36+
):
37+
"""self,
38+
node: Apply,
39+
storage_map: StorageMapType,
40+
compute_map: ComputeMapType,
41+
no_recycling: list[Variable],
42+
impl: str | None = None,"""
43+
44+
storage_map = {var: [None] for var in core_node.inputs + core_node.outputs}
45+
core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl)
46+
single_in = len(core_node.inputs) == 1
47+
core_input_storage = [storage_map[inp] for inp in core_node.inputs]
48+
core_output_storage = [storage_map[out] for out in core_node.outputs]
49+
core_storage = core_input_storage + core_output_storage
50+
51+
def vectorized_perform(
52+
*args,
53+
batch_bcast_patterns=batch_bcast_patterns,
54+
batch_ndim=batch_ndim,
55+
single_in=single_in,
56+
core_thunk=core_thunk,
57+
core_input_storage=core_input_storage,
58+
core_output_storage=core_output_storage,
59+
core_storage=core_storage,
60+
):
61+
if single_in:
62+
batch_shape = args[0].shape[:batch_ndim]
63+
else:
64+
_check_runtime_broadcast(args, batch_bcast_patterns, batch_ndim)
65+
batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
66+
args = list(args)
67+
for i, arg in enumerate(args):
68+
if arg.shape[:batch_ndim] != batch_shape:
69+
args[i] = faster_broadcast_to(
70+
arg, batch_shape + arg.shape[batch_ndim:]
71+
)
72+
73+
ndindex_iterator = faster_ndindex(*batch_shape)
74+
# Call once to get the output shapes
75+
try:
76+
# TODO: Pass core shape as input like BlockwiseWithCoreShape does?
77+
index0 = next(ndindex_iterator)
78+
except StopIteration:
79+
raise NotImplementedError("vectorize with zero size not implemented")
80+
else:
81+
for core_input, arg in zip(core_input_storage, args):
82+
core_input[0] = np.asarray(arg[index0])
83+
core_thunk()
84+
outputs = tuple(
85+
empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
86+
for core_output in core_output_storage
87+
)
88+
for output, core_output in zip(outputs, core_output_storage):
89+
output[index0] = core_output[0]
90+
91+
for index in ndindex_iterator:
92+
for core_input, arg in zip(core_input_storage, args):
93+
core_input[0] = np.asarray(arg[index])
94+
core_thunk()
95+
for output, core_output in zip(outputs, core_output_storage):
96+
output[index] = core_output[0]
97+
98+
# Clear storage
99+
for core_val in core_storage:
100+
core_val[0] = None
101+
return outputs
102+
103+
return vectorized_perform
104+
105+
106+
def _check_runtime_broadcast(numerical_inputs, batch_bcast_patterns, batch_ndim):
107+
# strict=None because we are in a hot loop
108+
# We zip together the dimension lengths of each input and their broadcast patterns
109+
for dim_lengths_and_bcast in zip(
110+
*[
111+
zip(input.shape[:batch_ndim], batch_bcast_pattern)
112+
for input, batch_bcast_pattern in zip(
113+
numerical_inputs, batch_bcast_patterns
114+
)
115+
],
116+
):
117+
# If for any dimension where an entry has dim_length != 1,
118+
# and another a dim_length of 1 and broadcastable=False, we have runtime broadcasting.
119+
if (
120+
any(d != 1 for d, _ in dim_lengths_and_bcast)
121+
and (1, False) in dim_lengths_and_bcast
122+
):
123+
raise ValueError(
124+
"Runtime broadcasting not allowed. "
125+
"At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
126+
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
127+
)
128+
129+
31130
class Blockwise(Op):
32131
"""Generalizes a core `Op` to work with batched dimensions.
33132
@@ -308,91 +407,62 @@ def L_op(self, inputs, outs, ograds):
308407

309408
return rval
310409

311-
def _create_node_gufunc(self, node) -> None:
410+
def _create_node_gufunc(self, node: Apply, impl) -> Callable:
312411
"""Define (or retrieve) the node gufunc used in `perform`.
313412
314413
If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
315414
Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
316415
317416
The gufunc is stored in the tag of the node.
318417
"""
319-
gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
320-
321-
if gufunc_spec is not None:
322-
gufunc = import_func_from_string(gufunc_spec[0])
323-
if gufunc is None:
418+
batch_ndim = self.batch_ndim(node)
419+
batch_bcast_patterns = [
420+
inp.type.broadcastable[:batch_ndim] for inp in node.inputs
421+
]
422+
if (
423+
gufunc_spec := self.gufunc_spec
424+
or getattr(self.core_op, "gufunc_spec", None)
425+
) is not None:
426+
core_func = import_func_from_string(gufunc_spec[0])
427+
if core_func is None:
324428
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
325429

326-
else:
327-
# Wrap core_op perform method in numpy vectorize
328-
n_outs = len(self.outputs_sig)
329-
core_node = self._create_dummy_core_node(node.inputs)
330-
inner_outputs_storage = [[None] for _ in range(n_outs)]
331-
332-
def core_func(
333-
*inner_inputs,
334-
core_node=core_node,
335-
inner_outputs_storage=inner_outputs_storage,
336-
):
337-
self.core_op.perform(
338-
core_node,
339-
[np.asarray(inp) for inp in inner_inputs],
340-
inner_outputs_storage,
341-
)
430+
if len(node.outputs) == 1:
342431

343-
if n_outs == 1:
344-
return inner_outputs_storage[0][0]
345-
else:
346-
return tuple(r[0] for r in inner_outputs_storage)
432+
def gufunc(*inputs):
433+
_check_runtime_broadcast(inputs, batch_bcast_patterns, batch_ndim)
434+
return (core_func(*inputs),)
435+
else:
347436

348-
gufunc = np.vectorize(core_func, signature=self.signature)
437+
def gufunc(*inputs):
438+
_check_runtime_broadcast(inputs, batch_bcast_patterns, batch_ndim)
439+
return core_func(*inputs)
440+
else:
441+
core_node = self._create_dummy_core_node(node.inputs) # type: ignore
442+
gufunc = _vectorize_node_perform(
443+
core_node,
444+
batch_bcast_patterns=batch_bcast_patterns,
445+
batch_ndim=self.batch_ndim(node),
446+
impl=impl,
447+
)
349448

350-
node.tag.gufunc = gufunc
449+
return gufunc
351450

352451
def _check_runtime_broadcast(self, node, inputs):
353452
batch_ndim = self.batch_ndim(node)
453+
batch_bcast = [inp.type.broadcastable[:batch_ndim] for inp in inputs]
454+
_check_runtime_broadcast(inputs, batch_bcast, batch_ndim)
354455

355-
# strict=False because we are in a hot loop
356-
for dims_and_bcast in zip(
357-
*[
358-
zip(
359-
input.shape[:batch_ndim],
360-
sinput.type.broadcastable[:batch_ndim],
361-
strict=False,
362-
)
363-
for input, sinput in zip(inputs, node.inputs, strict=False)
364-
],
365-
strict=False,
366-
):
367-
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
368-
raise ValueError(
369-
"Runtime broadcasting not allowed. "
370-
"At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
371-
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
372-
)
456+
def prepare_node(self, node, storage_map, compute_map, impl=None):
457+
node.tag.gufunc = self._create_node_gufunc(node, impl=impl)
373458

374459
def perform(self, node, inputs, output_storage):
375-
gufunc = getattr(node.tag, "gufunc", None)
376-
377-
if gufunc is None:
378-
# Cache it once per node
379-
self._create_node_gufunc(node)
460+
try:
380461
gufunc = node.tag.gufunc
381-
382-
self._check_runtime_broadcast(node, inputs)
383-
384-
res = gufunc(*inputs)
385-
if not isinstance(res, tuple):
386-
res = (res,)
387-
388-
# strict=False because we are in a hot loop
389-
for node_out, out_storage, r in zip(
390-
node.outputs, output_storage, res, strict=False
391-
):
392-
out_dtype = getattr(node_out, "dtype", None)
393-
if out_dtype and out_dtype != r.dtype:
394-
r = np.asarray(r, dtype=out_dtype)
395-
out_storage[0] = r
462+
except AttributeError:
463+
gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
464+
for out_storage, result in zip(output_storage, gufunc(*inputs)):
465+
out_storage[0] = result
396466

397467
def __str__(self):
398468
if self.name is None:

tests/tensor/test_blockwise.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
from pytensor.graph import Apply, Op
1313
from pytensor.graph.replace import vectorize_node
1414
from pytensor.raise_op import assert_op
15-
from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
15+
from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
1616
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
1717
from pytensor.tensor.nlinalg import MatrixInverse
1818
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
19+
from pytensor.tensor.signal import convolve1d
1920
from pytensor.tensor.slinalg import (
2021
Cholesky,
2122
Solve,
@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm
484485
benchmark(fn, *test_values)
485486

486487

488+
def test_small_blackwise_performance(benchmark):
489+
a = dmatrix(shape=(7, 128))
490+
b = dmatrix(shape=(7, 20))
491+
out = convolve1d(a, b, mode="valid")
492+
fn = pytensor.function([a, b], out, trust_input=True)
493+
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
494+
495+
rng = np.random.default_rng(495)
496+
a_test = rng.normal(size=a.type.shape)
497+
b_test = rng.normal(size=b.type.shape)
498+
np.testing.assert_allclose(
499+
fn(a_test, b_test),
500+
[
501+
np.convolve(a_test[i], b_test[i], mode="valid")
502+
for i in range(a_test.shape[0])
503+
],
504+
)
505+
benchmark(fn, a_test, b_test)
506+
507+
487508
def test_cop_with_params():
488509
matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
489510

0 commit comments

Comments
 (0)