Skip to content

Commit e0febc4

Browse files
committed
Flag Ops whose output types depend on input values
These nodes must always be rebuilt in non-strict mode
1 parent 0cb5fdb commit e0febc4

File tree

10 files changed

+109
-2
lines changed

10 files changed

+109
-2
lines changed

pytensor/graph/basic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,13 @@ def clone_with_new_inputs(
266266
assert isinstance(inputs, (list, tuple))
267267
remake_node = False
268268
new_inputs: List["Variable"] = list(inputs)
269+
270+
# Some Ops like Alloc require the node to always be rebuilt in non-strict mode
271+
# as the output type depends on the input values and not just their types
272+
output_type_depends_on_input_value = self.op._output_type_depends_on_input_value
273+
269274
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
270-
if curr.type != new.type:
275+
if (curr.type != new.type) or output_type_depends_on_input_value:
271276
if strict:
272277
new_i = curr.type.filter_variable(new)
273278
new_inputs[i] = new_i

pytensor/graph/op.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,15 @@ class Op(MetaObject):
207207
otypes: Optional[Sequence["Type"]] = None
208208
params_type: Optional[ParamsType] = None
209209

210+
_output_type_depends_on_input_value = False
211+
"""
212+
Whether the static output type depends on the inferred value of one of the inputs.
213+
(e.g., via constant folding or static shape inference).
214+
215+
This information is needed when rebuilding a graph with new inputs,
216+
as nodes with these Ops must be rebuilt even if the input types haven't changed.
217+
"""
218+
210219
def make_node(self, *inputs: Variable) -> Apply:
211220
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
212221

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,8 @@ class Alloc(COp):
14181418
"""
14191419

14201420
_f16_ok = True
1421+
_output_type_depends_on_input_value = True
1422+
14211423
__props__ = ()
14221424

14231425
def make_node(self, value, *shape):
@@ -3819,6 +3821,8 @@ def perform(self, node, inputs, outputs):
38193821
class AllocEmpty(COp):
38203822
"""Implement Alloc on the cpu, but without initializing memory."""
38213823

3824+
_output_type_depends_on_input_value = True
3825+
38223826
__props__ = ("dtype",)
38233827
params_type = ParamsType(typecode=int32)
38243828

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,6 +1561,8 @@ def broadcast_shape_iter(
15611561
class BroadcastTo(COp):
15621562
"""An `Op` for `numpy.broadcast_to`."""
15631563

1564+
_output_type_depends_on_input_value = True
1565+
15641566
__props__ = ()
15651567

15661568
view_map = {0: [0]}

pytensor/tensor/random/op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class RandomVariable(Op):
9191
9292
"""
9393

94+
_output_type_depends_on_input_value = True
95+
9496
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
9597
default_output = 1
9698

pytensor/tensor/shape.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ class SpecifyShape(COp):
388388
view_map = {0: [0]}
389389
__props__ = ()
390390
_f16_ok = True
391+
_output_type_depends_on_input_value = True
391392

392393
def make_node(self, x, *shape):
393394
from pytensor.tensor.basic import get_underlying_scalar_constant_value
@@ -587,6 +588,7 @@ class Reshape(COp):
587588

588589
view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
589590
_f16_ok = True
591+
_output_type_depends_on_input_value = True
590592

591593
check_input = False
592594
__props__ = ("ndim",)

tests/tensor/random/test_basic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pytensor.graph.basic import Constant, Variable, graph_inputs
1515
from pytensor.graph.fg import FunctionGraph
1616
from pytensor.graph.op import get_test_value
17+
from pytensor.graph.replace import clone_replace
1718
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1819
from pytensor.tensor.random.basic import (
1920
bernoulli,
@@ -57,7 +58,7 @@
5758
weibull,
5859
)
5960
from pytensor.tensor.rewriting.shape import ShapeFeature
60-
from pytensor.tensor.type import iscalar, scalar, tensor
61+
from pytensor.tensor.type import iscalar, scalar, tensor, vector
6162
from tests.unittest_tools import create_pytensor_param
6263

6364

@@ -1422,3 +1423,19 @@ def test_pickle():
14221423
a_unpkl = pickle.loads(a_pkl)
14231424

14241425
assert a_unpkl.owner.op._props() == sample_a.owner.op._props()
1426+
1427+
1428+
def test_rebuild():
1429+
x = vector(shape=(50,))
1430+
x_test = np.zeros((50,), dtype=config.floatX)
1431+
y = normal(size=x.shape)
1432+
assert y.type.shape == (50,)
1433+
assert y.shape.eval({x: x_test}) == (50,)
1434+
assert y.eval({x: x_test}).shape == (50,)
1435+
1436+
x_new = vector(shape=(100,))
1437+
x_new_test = np.zeros((100,), dtype=config.floatX)
1438+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
1439+
assert y_new.type.shape == (100,)
1440+
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
1441+
assert y_new.eval({x_new: x_new_test}).shape == (100,)

tests/tensor/test_basic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.gradient import grad, hessian
1717
from pytensor.graph.basic import Apply
1818
from pytensor.graph.op import Op
19+
from pytensor.graph.replace import clone_replace
1920
from pytensor.misc.safe_asarray import _asarray
2021
from pytensor.raise_op import Assert
2122
from pytensor.scalar import autocast_float, autocast_float_as
@@ -818,6 +819,22 @@ def test_full(self):
818819
res = pytensor.function([], full_at, mode=self.mode)()
819820
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
820821

822+
@pytest.mark.parametrize("func", (at.zeros, at.empty))
823+
def test_rebuild(self, func):
824+
x = vector(shape=(50,))
825+
x_test = np.zeros((50,), dtype=config.floatX)
826+
y = func(x.shape)
827+
assert y.type.shape == (50,)
828+
assert y.shape.eval({x: x_test}) == (50,)
829+
assert y.eval({x: x_test}).shape == (50,)
830+
831+
x_new = vector(shape=(100,))
832+
x_new_test = np.zeros((100,), dtype=config.floatX)
833+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
834+
assert y_new.type.shape == (100,)
835+
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
836+
assert y_new.eval({x_new: x_new_test}).shape == (100,)
837+
821838

822839
def test_infer_shape():
823840
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):

tests/tensor/test_extra_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.mode import Mode
1010
from pytensor.configdefaults import config
1111
from pytensor.graph.basic import Constant, applys_between
12+
from pytensor.graph.replace import clone_replace
1213
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1314
from pytensor.raise_op import Assert
1415
from pytensor.tensor.elemwise import DimShuffle
@@ -1399,6 +1400,22 @@ def test_inplace(self):
13991400

14001401
assert advincsub_node.op.inplace is False
14011402

1403+
def test_rebuild(self):
1404+
x = vector(shape=(50,))
1405+
x_test = np.zeros((50,), dtype=config.floatX)
1406+
i = 0
1407+
y = broadcast_to(i, x.shape)
1408+
assert y.type.shape == (50,)
1409+
assert y.shape.eval({x: x_test}) == (50,)
1410+
assert y.eval({x: x_test}).shape == (50,)
1411+
1412+
x_new = vector(shape=(100,))
1413+
x_new_test = np.zeros((100,), dtype=config.floatX)
1414+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
1415+
assert y_new.type.shape == (100,)
1416+
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
1417+
assert y_new.eval({x_new: x_new_test}).shape == (100,)
1418+
14021419

14031420
def test_broadcast_arrays():
14041421
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()

tests/tensor/test_shape.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytensor.configdefaults import config
88
from pytensor.graph.basic import Variable
99
from pytensor.graph.fg import FunctionGraph
10+
from pytensor.graph.replace import clone_replace
1011
from pytensor.graph.type import Type
1112
from pytensor.misc.safe_asarray import _asarray
1213
from pytensor.scalar.basic import ScalarConstant
@@ -337,6 +338,21 @@ def test_more_shapes(self):
337338
Reshape,
338339
)
339340

341+
def test_rebuild(self):
342+
x = as_tensor_variable(50)
343+
i = vector("i")
344+
i_test = np.zeros((100,), dtype=config.floatX)
345+
y = reshape(i, (100 // x, x))
346+
assert y.type.shape == (2, 50)
347+
assert tuple(y.shape.eval({i: i_test})) == (2, 50)
348+
assert y.eval({i: i_test}).shape == (2, 50)
349+
350+
x_new = as_tensor_variable(25)
351+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
352+
assert y_new.type.shape == (4, 25)
353+
assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
354+
assert y_new.eval({i: i_test}).shape == (4, 25)
355+
340356

341357
def test_shape_i_hash():
342358
assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
@@ -524,6 +540,22 @@ def test_specify_shape_in_grad(self):
524540
z_grad = grad(z.sum(), wrt=x)
525541
assert isinstance(z_grad.owner.op, SpecifyShape)
526542

543+
def test_rebuild(self):
544+
x = as_tensor_variable(50)
545+
i = matrix("i")
546+
i_test = np.zeros((4, 50), dtype=config.floatX)
547+
y = specify_shape(i, (None, x))
548+
assert y.type.shape == (None, 50)
549+
assert tuple(y.shape.eval({i: i_test})) == (4, 50)
550+
assert y.eval({i: i_test}).shape == (4, 50)
551+
552+
x_new = as_tensor_variable(100)
553+
i_test = np.zeros((4, 100), dtype=config.floatX)
554+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
555+
assert y_new.type.shape == (None, 100)
556+
assert tuple(y_new.shape.eval({i: i_test})) == (4, 100)
557+
assert y_new.eval({i: i_test}).shape == (4, 100)
558+
527559

528560
class TestSpecifyBroadcastable:
529561
def test_basic(self):

0 commit comments

Comments
 (0)