Skip to content

Commit 24c20dc

Browse files
committed
Flag Ops whose output types depend on input values
These nodes must always be rebuilt in non-strict mode
1 parent c36f731 commit 24c20dc

File tree

9 files changed

+102
-2
lines changed

9 files changed

+102
-2
lines changed

pytensor/graph/basic.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -266,8 +266,15 @@ def clone_with_new_inputs(
         assert isinstance(inputs, (list, tuple))
         remake_node = False
         new_inputs: List["Variable"] = list(inputs)
+
+        # Some Ops like Alloc require the node to always be rebuilt in non-strict mode
+        # as the output type depends on the input values and not just their types
+        output_type_depends_on_input_value = getattr(
+            self.op, "_output_type_depends_on_input_value", False
+        )
+
         for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
-            if curr.type != new.type:
+            if (curr.type != new.type) or output_type_depends_on_input_value:
                 if strict:
                     new_i = curr.type.filter_variable(new)
                 new_inputs[i] = new_i
```

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1418,6 +1418,8 @@ class Alloc(COp):
     """
 
     _f16_ok = True
+    _output_type_depends_on_input_value = True
+
     __props__ = ()
 
     def make_node(self, value, *shape):
@@ -3817,6 +3819,8 @@ def perform(self, node, inputs, outputs):
 class AllocEmpty(COp):
     """Implement Alloc on the cpu, but without initializing memory."""
 
+    _output_type_depends_on_input_value = True
+
     __props__ = ("dtype",)
     params_type = ParamsType(typecode=int32)
```

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1630,6 +1630,8 @@ def broadcast_shape_iter(
 class BroadcastTo(COp):
     """An `Op` for `numpy.broadcast_to`."""
 
+    _output_type_depends_on_input_value = True
+
     __props__ = ()
 
     view_map = {0: [0]}
```

pytensor/tensor/random/op.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -91,6 +91,8 @@ class RandomVariable(Op):
 
     """
 
+    _output_type_depends_on_input_value = True
+
     __props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
     default_output = 1
 
```

pytensor/tensor/shape.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -388,6 +388,7 @@ class SpecifyShape(COp):
     view_map = {0: [0]}
     __props__ = ()
     _f16_ok = True
+    _output_type_depends_on_input_value = True
 
     def make_node(self, x, *shape):
         from pytensor.tensor.basic import get_underlying_scalar_constant_value
@@ -587,6 +588,7 @@ class Reshape(COp):
 
     view_map = {0: [0]}  # output 0 is potentially aliased to inputs [0]
     _f16_ok = True
+    _output_type_depends_on_input_value = True
 
     check_input = False
     __props__ = ("ndim",)
```

tests/tensor/random/test_basic.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -14,6 +14,7 @@
 from pytensor.graph.basic import Constant, Variable, graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.random.basic import (
     bernoulli,
@@ -57,7 +58,7 @@
     weibull,
 )
 from pytensor.tensor.rewriting.shape import ShapeFeature
-from pytensor.tensor.type import iscalar, scalar, tensor
+from pytensor.tensor.type import iscalar, scalar, tensor, vector
 from tests.unittest_tools import create_pytensor_param
 
 
@@ -1422,3 +1423,19 @@ def test_pickle():
     a_unpkl = pickle.loads(a_pkl)
 
     assert a_unpkl.owner.op._props() == sample_a.owner.op._props()
+
+
+def test_rebuild():
+    x = vector(shape=(50,))
+    x_test = np.zeros((50,), dtype=config.floatX)
+    y = normal(size=x.shape)
+    assert y.type.shape == (50,)
+    assert y.shape.eval({x: x_test}) == (50,)
+    assert y.eval({x: x_test}).shape == (50,)
+
+    x_new = vector(shape=(100,))
+    x_new_test = np.zeros((100,), dtype=config.floatX)
+    y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+    assert y_new.type.shape == (100,)
+    assert y_new.shape.eval({x_new: x_new_test}) == (100,)
+    assert y_new.eval({x_new: x_new_test}).shape == (100,)
```

tests/tensor/test_basic.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -16,6 +16,7 @@
 from pytensor.gradient import grad, hessian
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
+from pytensor.graph.replace import clone_replace
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.raise_op import Assert
 from pytensor.scalar import autocast_float, autocast_float_as
@@ -818,6 +819,22 @@ def test_full(self):
         res = pytensor.function([], full_at, mode=self.mode)()
         assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
 
+    @pytest.mark.parametrize("func", (at.zeros, at.empty))
+    def test_rebuild(self, func):
+        x = vector(shape=(50,))
+        x_test = np.zeros((50,), dtype=config.floatX)
+        y = func(x.shape)
+        assert y.type.shape == (50,)
+        assert y.shape.eval({x: x_test}) == (50,)
+        assert y.eval({x: x_test}).shape == (50,)
+
+        x_new = vector(shape=(100,))
+        x_new_test = np.zeros((100,), dtype=config.floatX)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (100,)
+        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
+        assert y_new.eval({x_new: x_new_test}).shape == (100,)
+
 
 def test_infer_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
```

tests/tensor/test_extra_ops.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Constant, applys_between
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.raise_op import Assert
 from pytensor.tensor.elemwise import DimShuffle
@@ -1393,6 +1394,22 @@ def test_inplace(self):
 
         assert advincsub_node.op.inplace is False
 
+    def test_rebuild(self):
+        x = vector(shape=(50,))
+        x_test = np.zeros((50,), dtype=config.floatX)
+        i = 0
+        y = broadcast_to(i, x.shape)
+        assert y.type.shape == (50,)
+        assert y.shape.eval({x: x_test}) == (50,)
+        assert y.eval({x: x_test}).shape == (50,)
+
+        x_new = vector(shape=(100,))
+        x_new_test = np.zeros((100,), dtype=config.floatX)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (100,)
+        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
+        assert y_new.eval({x_new: x_new_test}).shape == (100,)
+
 
 def test_broadcast_arrays():
     x, y = at.dvector(), at.dmatrix()
```

tests/tensor/test_shape.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Variable
 from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.replace import clone_replace
 from pytensor.graph.type import Type
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar.basic import ScalarConstant
@@ -337,6 +338,21 @@ def test_more_shapes(self):
             Reshape,
         )
 
+    def test_rebuild(self):
+        x = as_tensor_variable(50)
+        i = vector("i")
+        i_test = np.zeros((100,), dtype=config.floatX)
+        y = reshape(i, (100 // x, x))
+        assert y.type.shape == (2, 50)
+        assert tuple(y.shape.eval({i: i_test})) == (2, 50)
+        assert y.eval({i: i_test}).shape == (2, 50)
+
+        x_new = as_tensor_variable(25)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (4, 25)
+        assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
+        assert y_new.eval({i: i_test}).shape == (4, 25)
+
 
 def test_shape_i_hash():
     assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
@@ -524,6 +540,22 @@ def test_specify_shape_in_grad(self):
         z_grad = grad(z.sum(), wrt=x)
         assert isinstance(z_grad.owner.op, SpecifyShape)
 
+    def test_rebuild(self):
+        x = as_tensor_variable(50)
+        i = matrix("i")
+        i_test = np.zeros((4, 50), dtype=config.floatX)
+        y = specify_shape(i, (None, x))
+        assert y.type.shape == (None, 50)
+        assert tuple(y.shape.eval({i: i_test})) == (4, 50)
+        assert y.eval({i: i_test}).shape == (4, 50)
+
+        x_new = as_tensor_variable(100)
+        i_test = np.zeros((4, 100), dtype=config.floatX)
+        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
+        assert y_new.type.shape == (None, 100)
+        assert tuple(y_new.shape.eval({i: i_test})) == (4, 100)
+        assert y_new.eval({i: i_test}).shape == (4, 100)
+
 
 class TestSpecifyBroadcastable:
     def test_basic(self):
```

0 commit comments

Comments (0)