
Commit 2e95049

Get rid of expensive Blockwise(Reshape)
1 parent 47cd634 commit 2e95049

File tree

5 files changed: +215 -91 lines

pytensor/tensor/rewriting/blockwise.py

Lines changed: 55 additions & 9 deletions
```diff
@@ -1,15 +1,19 @@
+from pytensor import Variable
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
     register_stabilize,
 )
+from pytensor.tensor.rewriting.uncanonicalize import local_dimshuffle_alloc
+from pytensor.tensor.shape import Reshape
 from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor
 
 
@@ -70,7 +74,7 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
         Dot | Alloc | ARange | Subtensor | AdvancedSubtensor | AdvancedIncSubtensor,
     ):
         # Many Dot-related rewrites (eg, all of BlasOpt) happen before specialize
-        # These other Ops can't always be trivially vectored at runtime,
+        # These other Ops can't always be trivially vectorized at runtime,
         # since their inputs may imply non-rectangular shapes.
         return local_useless_unbatched_blockwise.fn(fgraph, node)
 
@@ -86,6 +90,18 @@ def _squeeze_left(x, stop_at_dim: int | None = None):
     return x.squeeze(axis=tuple(range(squeeze_ndim)))
 
 
+def alloc_or_expand_dims_of_alloc(var: Variable) -> bool:
+    return var.owner and (
+        isinstance(var.owner.op, Alloc)
+        or (
+            isinstance(var.owner.op, DimShuffle)
+            and var.owner.inputs[0].owner
+            and isinstance(var.owner.inputs[0].owner.op, Alloc)
+        )
+    )
+
+
+@register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Blockwise])
 def local_blockwise_alloc(fgraph, node):
```
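The new `alloc_or_expand_dims_of_alloc` helper only gates the rewrite: it checks whether a variable is produced by an `Alloc`, or by a `DimShuffle` (e.g. an `expand_dims`) sitting directly on top of an `Alloc`. A rough sketch of the two graph patterns it is meant to match, using standard PyTensor builders (not taken from this commit; the variable names are illustrative):

```python
# Sketch of the two patterns alloc_or_expand_dims_of_alloc() is meant to match.
# Assumes a standard PyTensor installation.
import pytensor.tensor as pt

v = pt.vector("v")
broadcasted = pt.alloc(v, 10, 5)           # a plain Alloc node
expanded = pt.expand_dims(broadcasted, 0)  # a DimShuffle wrapping an Alloc

# Both graphs should satisfy the predicate, so local_blockwise_alloc below is
# attempted even when the Alloc is hidden behind an expand_dims.
print(type(broadcasted.owner.op).__name__)               # Alloc
print(type(expanded.owner.op).__name__)                   # DimShuffle
print(type(expanded.owner.inputs[0].owner.op).__name__)   # Alloc
```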
```diff
@@ -97,19 +113,25 @@ def local_blockwise_alloc(fgraph, node):
     BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)
     """
 
-    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
-        return None
-
     op: Blockwise = node.op  # type: ignore
 
     batch_ndim = op.batch_ndim(node)
     if not batch_ndim:
         return None
 
+    if not any(alloc_or_expand_dims_of_alloc(var) for var in node.inputs):
+        return None
+
     new_inputs = []
     batch_shapes = []
     can_push_any_alloc = False
     for inp, inp_sig in zip(node.inputs, op.inputs_sig):
+        if inp.owner and isinstance(inp.owner.op, DimShuffle):
+            # Convert DimShuffle of Alloc to Alloc
+            new_inp = local_dimshuffle_alloc.transform(None, inp.owner)
+            if new_inp:
+                [inp] = new_inp
+
         if inp.owner and isinstance(inp.owner.op, Alloc):
             # Push batch dims from Alloc
             value, *shape = inp.owner.inputs
@@ -167,17 +189,15 @@ def local_blockwise_alloc(fgraph, node):
         missing_ndim = old_out_type.ndim - new_out_type.ndim
         batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
         for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
+            if old_out_type.broadcastable[i]:
+                continue
             for batch_dim in batch_dims:
                 if batch_dim == 1:
                     continue
+                batch_shape[i] = batch_dim
                 if isinstance(batch_dim, Constant):
                     # Give preference to Constants
-                    batch_shape[i] = batch_dim
                     break
-                elif old_out_type.broadcastable[i]:
-                    # Only use non Constant shapes if absolutely necessary
-                    # Otherwise, we use the shape of the non-alloc output
-                    batch_shape[i] = batch_dim
 
         copy_stack_trace(node.outputs, new_outs)
         new_outs = [
```
```diff
@@ -190,3 +210,29 @@ def local_blockwise_alloc(fgraph, node):
         ]
     copy_stack_trace(node.outputs, new_outs)
     return new_outs
+
+
+@register_canonicalize
+@register_specialize
+@node_rewriter([Blockwise])
+def local_blockwise_reshape(fgraph, node):
+    """Rewrite away square Blockwise reshapes.
+
+    Reshape is tricky to vectorize eagerly, because a graph like
+    `x.reshape([x.shape[0] * x.shape[1], -1])` has many operations
+    that must be vectorized before we arrive at the reshape operation.
+
+    For the square Reshape case, we must wait for all the intermediate
+    operations to be lifted as Allocs.
+    """
+    if not isinstance(node.op.core_op, Reshape):
+        return None
+
+    x, output_shape = node.inputs
+    batch_ndim = node.op.batch_ndim(node)
+    if all(output_shape.type.broadcastable[:batch_ndim]):
+        batched_shape = x.shape[:batch_ndim]
+        core_reshape = _squeeze_left(output_shape, batch_ndim)
+        new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
+        copy_stack_trace(node.outputs[0], new_out)
+        return [new_out]
```
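Taken together, the new `local_blockwise_reshape` rewrite replaces a `Blockwise(Reshape)` node, once canonicalization has run, with a single plain `Reshape` whose target shape keeps the leading batch dimensions. A minimal sketch of the intended behaviour, mirroring the new test added below (it assumes a PyTensor build that already contains this commit):

```python
# Sketch mirroring test_blockwise_reshape; assumes a PyTensor build with this commit.
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph

x = pt.tensor("x", shape=(None, None, None))
y = x.reshape([x.shape[0] * x.shape[1], -1])

# Vectorizing over a new leading dimension initially yields Blockwise(Reshape).
new_x = pt.tensor("x", shape=(None, None, None, None))
new_y = vectorize_graph(y, {x: new_x})

# After canonicalize/specialize the Blockwise wrapper should be gone, leaving a
# plain Reshape whose target shape starts with the batch dimension.
rewritten_y = rewrite_graph(new_y, include=("canonicalize", "specialize"), clone=True)
print(type(rewritten_y.owner.op).__name__)  # expected: Reshape
```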

pytensor/tensor/rewriting/shape.py

Lines changed: 56 additions & 65 deletions
```diff
@@ -32,7 +32,6 @@
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
-    register_stabilize,
     register_useless,
     topo_constant_folding,
 )
@@ -749,51 +748,43 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
 
 
-def local_reshape_chain(op):
-    @node_rewriter([op])
-    def f(fgraph, node):
-        """
-        Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
-
-        """
-        if not check_chain(node, op, op):
-            return False
-
-        # TODO: this can permit a failing program to run by eliminating
-        # the lower reshape
-        rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-        # Copy over stacktrace from previous output node, as any error
-        # in new computational graph would have been caused by last op
-        # in the old computational graph.
-        copy_stack_trace(node.outputs, rval)
-
-        # It might happen that the desired output of this node has a
-        # broadcastable pattern that does not match that of 'rval'. This is
-        # when originally, we were able to figure out that one of the
-        # dimensions of the reshape is one, but some other transformation
-        # replaced the shape by one for which this cannot be guessed.
-        # We should try to figure out why we lost the information about this
-        # constant value... but in the meantime, better not apply this
-        # rewrite.
-        if rval.type.ndim == node.outputs[0].type.ndim and all(
-            s1 == s2
-            for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
-            if s1 == 1 or s2 == 1
-        ):
-            return [rval]
-        else:
-            return False
-
-    return f
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
+@node_rewriter([Reshape])
+def local_reshape_chain(fgraph, node):
+    """
+    Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
 
+    """
+    if not check_chain(node, Reshape, Reshape):
+        return False
 
-register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain")
+    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
+
+    # Copy over stacktrace from previous output node, as any error
+    # in new computational graph would have been caused by last op
+    # in the old computational graph.
+    copy_stack_trace(node.outputs, rval)
+
+    # It might happen that the desired output of this node has a
+    # broadcastable pattern that does not match that of 'rval'. This is
+    # when originally, we were able to figure out that one of the
+    # dimensions of the reshape is one, but some other transformation
+    # replaced the shape by one for which this cannot be guessed.
+    # We should try to figure out why we lost the information about this
+    # constant value... but in the meantime, better not apply this
+    # rewrite.
+    if rval.type.ndim == node.outputs[0].type.ndim and all(
+        s1 == s2
+        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape)
+        if s1 == 1 or s2 == 1
+    ):
+        return [rval]
 
 
-@register_useless
-@register_canonicalize
-@register_stabilize
+@register_useless("shape_unsafe")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Reshape])
 def local_useless_reshape(fgraph, node):
     """Remove two kinds of useless `Reshape`.
```
```diff
@@ -802,24 +793,17 @@ def local_useless_reshape(fgraph, node):
     - Remove `Reshape` when reshaping to the shape of the input.
 
     """
-    inp = node.inputs[0]
-    output = node.outputs[0]
-    output_shape = node.inputs[1]
+    inp, output_shape = node.inputs
+    [output] = node.outputs
 
     if inp.type.ndim != output.type.ndim:
         return False
 
     # Simple case: both input and output have a single dimension.
-    # TODO FIXME XXX: This could hide errors if the user provides inconsistent
-    # shapes.
     if (
         inp.type.ndim == 1
         and output.type.ndim == 1
-        and all(
-            s1 == s2
-            for s1, s2 in zip(inp.type.shape, output.type.shape)
-            if s1 == 1 or s2 == 1
-        )
+        and inp.type.broadcastable == output.type.broadcastable
     ):
         return [inp]
 
@@ -832,8 +816,15 @@ def local_useless_reshape(fgraph, node):
 
     # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
     # broadcastable and constant dimensions
-    if output_shape.owner and isinstance(output_shape.owner.op, MakeVector):
-        output_shape_is = output_shape.owner.inputs
+    if isinstance(output_shape, Constant) or (
+        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
+    ):
+        if isinstance(output_shape, Constant):
+            output_shape_is = [
+                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
+            ]
+        else:
+            output_shape_is = output_shape.owner.inputs
 
         shape_feature = getattr(fgraph, "shape_feature", None)
 
@@ -865,9 +856,9 @@ def local_useless_reshape(fgraph, node):
             shape_match[dim] = True
             continue
 
-        # Match 1 if input.type.shape[dim] == 1
+        # Match constant if input.type.shape[dim] == constant
         cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
-        if inp.type.shape[dim] == 1 and cst_outshp_i == 1:
+        if inp.type.shape[dim] == cst_outshp_i:
             shape_match[dim] = True
             continue
 
@@ -881,17 +872,18 @@ def local_useless_reshape(fgraph, node):
         if shape_feature:
             inpshp_i = shape_feature.get_shape(inp, dim)
             if inpshp_i == outshp_i or (
-                extract_constant(inpshp_i, only_process_constants=1)
-                == extract_constant(outshp_i, only_process_constants=1)
+                extract_constant(inpshp_i, only_process_constants=True)
+                == extract_constant(outshp_i, only_process_constants=True)
             ):
                 shape_match[dim] = True
                 continue
 
-    if all(shape_match) and nb_m1 <= 1:
+    if nb_m1 <= 1 and all(shape_match):
+        return [inp]
+
+    if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
         return [inp]
 
-    # TODO later: if all the shapes except one match, we may want to
-    # consider it useless as well, like we do in the 1-dim case.
     return False
 
 
```
```diff
@@ -910,9 +902,8 @@ def local_reshape_to_dimshuffle(fgraph, node):
     -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
     """
     op = node.op
-    inp = node.inputs[0]
-    output = node.outputs[0]
-    output_shape = node.inputs[1]
+    inp, output_shape = node.inputs
+    [output] = node.outputs
 
     dimshuffle_new_order = []
     new_output_shape = []
```
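The only change to `local_reshape_to_dimshuffle` is the tidier unpacking of the node's inputs and outputs; the rewrite itself still turns broadcastable 1s in the target shape into a `DimShuffle` around a smaller core `Reshape`, as the docstring above shows. A hedged sketch of that behaviour (standard PyTensor API; the exact resulting graph may differ once other canonicalizations run):

```python
# A reshape that only inserts length-1 dimensions should end up as a DimShuffle,
# possibly around a smaller Reshape. Assumes a standard PyTensor installation.
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.matrix("x")
y = x.reshape((1, x.shape[0], 1, x.shape[1]))

rewritten = rewrite_graph(y, include=("canonicalize",), clone=True)
# Expected: a DimShuffle introduces the broadcastable dimensions and no
# full-rank Reshape of `x` remains.
print(type(rewritten.owner.op).__name__)
```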
```diff
@@ -944,7 +935,7 @@ def local_reshape_lift(fgraph, node):
 
 
 @register_canonicalize
-@register_stabilize
+@register_specialize
 @node_rewriter([Reshape])
 def local_reshape_lift(fgraph, node):
     """
```

tests/tensor/rewriting/test_blockwise.py

Lines changed: 29 additions & 2 deletions
```diff
@@ -1,14 +1,17 @@
 from functools import partial
 
-from pytensor import function
-from pytensor.graph import FunctionGraph, rewrite_graph
+import numpy as np
+
+from pytensor import Mode, function
+from pytensor.graph import FunctionGraph, rewrite_graph, vectorize_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.scalar import log as scalar_log
 from pytensor.tensor import add, alloc, matrix, tensor, tensor3
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.nlinalg import MatrixPinv
 from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
+from pytensor.tensor.shape import Reshape
 
 
 def test_useless_blockwise_of_elemwise():
@@ -118,3 +121,27 @@ def test_blockwise_alloc():
     out = vector_add(x, alloc(y, 5))
     expected_out = out
     assert equal([rewrite(out)], [expected_out])
+
+
+def test_blockwise_reshape():
+    x = tensor("x", shape=(None, None, None))
+    y = x.reshape([x.shape[0] * x.shape[1], -1])
+
+    new_x = tensor("x", shape=(None, None, None, None))
+    new_y = vectorize_graph(y, {x: new_x})
+    assert not isinstance(new_y.owner.op, Reshape)
+    assert isinstance(new_y.owner.op, Blockwise) and isinstance(
+        new_y.owner.op.core_op, Reshape
+    )
+
+    rewritten_y = rewrite_graph(
+        new_y, include=("canonicalize", "specialize"), clone=True
+    )
+    assert isinstance(rewritten_y.owner.op, Reshape)
+
+    no_rewrites = Mode(linker="py", optimizer=None)
+    test_x = np.arange(5 * 4 * 3 * 2).reshape(5, 4, 3, 2)
+    np.testing.assert_allclose(
+        new_y.eval({"x": test_x}, mode=no_rewrites),
+        rewritten_y.eval({"x": test_x}, mode=no_rewrites),
+    )
```
