Skip to content

Commit b496127

Browse files
committed
Don't copy inputs in constant_fold
1 parent: 7af0a87 · commit: b496127

File tree

3 files changed

+43
-20
lines changed

3 files changed

+43
-20
lines changed

pymc/pytensorf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1045,7 +1045,7 @@ def constant_fold(
10451045
attempting constant folding, and any old non-shared inputs will not work with
10461046
the returned outputs
10471047
"""
1048-
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], clone=True)
1048+
fg = FunctionGraph(outputs=xs, features=[ShapeFeature()], copy_inputs=False, clone=True)
10491049

10501050
# By default, rewrite_graph includes canonicalize which includes constant-folding as the final rewrite
10511051
folded_xs = rewrite_graph(fg).outputs

tests/model/transform/test_optimization.py

+15
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
from pymc import Deterministic, do
2121
from pymc.data import Data
2222
from pymc.distributions import HalfNormal, Normal
23+
from pymc.exceptions import NotConstantValueError
2324
from pymc.model import Model
2425
from pymc.model.transform.optimization import freeze_dims_and_data
26+
from pymc.pytensorf import constant_fold
2527

2628

2729
def test_freeze_dims_and_data():
@@ -144,3 +146,16 @@ def test_freeze_dim_after_do_intervention():
144146

145147
frozen_do_m = freeze_dims_and_data(do_m)
146148
assert frozen_do_m["x"].type.shape == (5,)
149+
150+
151+
def test_freeze_dims_and_data_partially_observed_rv():
152+
# Regression test for #7387
153+
154+
with Model(coords={"a": [0, 1, 2]}) as model:
155+
y = Normal("y", 0, observed=[0, 0, np.nan], dims="a")
156+
157+
with pytest.raises(NotConstantValueError):
158+
constant_fold([y.shape])
159+
160+
frozen_y = freeze_dims_and_data(model)["y"]
161+
assert constant_fold([frozen_y.shape]) == (3,)

tests/test_pytensorf.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -646,25 +646,33 @@ def test_reseed_rngs():
646646
assert rng.get_value().bit_generator.state == bit_generator.state
647647

648648

649-
def test_constant_fold():
650-
x = pt.random.normal(size=(5,))
651-
y = pt.arange(x.size)
652-
653-
res = constant_fold((y, y.shape))
654-
assert np.array_equal(res[0], np.arange(5))
655-
assert tuple(res[1]) == (5,)
656-
657-
658-
def test_constant_fold_raises():
659-
size = pytensor.shared(5)
660-
x = pt.random.normal(size=(size,))
661-
y = pt.arange(x.size)
662-
663-
with pytest.raises(NotConstantValueError):
664-
constant_fold((y, y.shape))
665-
666-
res = constant_fold((y, y.shape), raise_not_constant=False)
667-
assert tuple(res[1].eval()) == (5,)
649+
class TestConstantFold:
650+
def test_constant_fold(self):
651+
x = pt.random.normal(size=(5,))
652+
y = pt.arange(x.size)
653+
654+
res = constant_fold((y, y.shape))
655+
assert np.array_equal(res[0], np.arange(5))
656+
assert tuple(res[1]) == (5,)
657+
658+
def test_constant_fold_raises(self):
659+
size = pytensor.shared(5)
660+
x = pt.random.normal(size=(size,))
661+
y = pt.arange(x.size)
662+
663+
with pytest.raises(NotConstantValueError):
664+
constant_fold((y, y.shape))
665+
666+
res = constant_fold((y, y.shape), raise_not_constant=False)
667+
assert tuple(res[1].eval()) == (5,)
668+
669+
def test_inputs_preserved(self):
670+
# Make sure constant_folded graph depends on original graph inputs (not copies)
671+
# Regression test for #7387
672+
a = pt.scalar("a", dtype="int")
673+
out = pt.empty((a,))
674+
(out_shape,) = constant_fold((out.shape[0],), raise_not_constant=False)
675+
assert out_shape is a
668676

669677

670678
def test_replace_vars_in_graphs():

0 commit comments

Comments (0)