
Allow rebuilding graphs when output type depends on input values #280

Merged (2 commits) on Jun 23, 2023
12 changes: 11 additions & 1 deletion pytensor/graph/basic.py
@@ -266,14 +266,24 @@ def clone_with_new_inputs(
assert isinstance(inputs, (list, tuple))
remake_node = False
new_inputs: List["Variable"] = list(inputs)

# Some Ops like Alloc require the node to always be rebuilt in non-strict mode
# as the output type depends on the input values and not just their types
output_type_depends_on_input_value = self.op._output_type_depends_on_input_value

for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if curr.type != new.type:
# Check if the input type changed or if the Op has output types that depend on input values
if (curr.type != new.type) or output_type_depends_on_input_value:
Review comment (Member):
The branching logic isn't very clear at first glance to an outsider; can you please add comments that explain exactly what happens in each branch?

Reply (Member Author):
Done.
# In strict mode, the cloned graph is assumed to be mathematically equivalent to the original one.
# We only need to rebuild a node when the new input has a different, but compatible, type.
# This can happen e.g., when we provide a new input with a more specialized static shape.
if strict:
new_i = curr.type.filter_variable(new)
new_inputs[i] = new_i

if curr.type != new_i.type:
remake_node = True
# Otherwise, we always rebuild the node
else:
remake_node = True

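For context, a minimal, untested sketch of what this branch enables (assuming `pytensor.tensor.alloc` builds the `Alloc` node directly and that the shape argument lands in `node.inputs[1]`): replacing a shape input with a new constant of the same type now rebuilds the node in non-strict mode, so the static output type follows the new value instead of going stale.

```python
import pytensor.tensor as pt

x = pt.alloc(0.0, 5)  # Alloc node; static output shape inferred as (5,)
assert x.type.shape == (5,)

node = x.owner
# New shape constant with the same type as the old shape input.
new_shape = pt.constant(10, dtype=node.inputs[1].dtype)

# Before this change the node was reused (the input types match), keeping the
# stale (5,) output type; with `_output_type_depends_on_input_value = True`
# the node is rebuilt via make_node and the output type is re-inferred.
rebuilt = node.clone_with_new_inputs([node.inputs[0], new_shape], strict=False)
assert rebuilt.outputs[0].type.shape == (10,)
```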
9 changes: 9 additions & 0 deletions pytensor/graph/op.py
@@ -207,6 +207,15 @@ class Op(MetaObject):
otypes: Optional[Sequence["Type"]] = None
params_type: Optional[ParamsType] = None

_output_type_depends_on_input_value = False
"""
Whether the static output type depends on the inferred value of one of the inputs.
(e.g., via constant folding or static shape inference).

This information is needed when rebuilding a graph with new inputs,
as nodes with these Ops must be rebuilt even if the input types haven't changed.
"""

def make_node(self, *inputs: Variable) -> Apply:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.

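As a hypothetical illustration for Op authors (the `OnesVector` Op below is made up for this sketch and is not part of the diff), an Op whose static output shape is inferred from the value of an input would opt into the new behaviour like this:

```python
import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import TensorType


class OnesVector(Op):
    """Hypothetical Op returning a float64 vector of ones of a given length."""

    __props__ = ()
    # The static output shape is read from the *value* of the length input,
    # so cloned graphs must rebuild this node even when the input types match.
    _output_type_depends_on_input_value = True

    def make_node(self, length):
        length = as_tensor_variable(length)
        try:
            static_length = int(get_underlying_scalar_constant_value(length))
        except NotScalarConstantError:
            static_length = None
        out = TensorType("float64", shape=(static_length,))()
        return Apply(self, [length], [out])

    def perform(self, node, inputs, output_storage):
        (length,) = inputs
        output_storage[0][0] = np.ones(int(length), dtype="float64")
```

With the flag set, a non-strict rebuild (e.g. `clone_replace(..., rebuild_strict=False)`) re-runs `make_node` and re-infers the static shape from the new input value, mirroring what the tests below exercise for `Alloc`, `BroadcastTo`, `RandomVariable`, `SpecifyShape`, and `Reshape`.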
4 changes: 4 additions & 0 deletions pytensor/tensor/basic.py
@@ -1418,6 +1418,8 @@ class Alloc(COp):
"""

_f16_ok = True
_output_type_depends_on_input_value = True

__props__ = ()

def make_node(self, value, *shape):
@@ -3819,6 +3821,8 @@ def perform(self, node, inputs, outputs):
class AllocEmpty(COp):
"""Implement Alloc on the cpu, but without initializing memory."""

_output_type_depends_on_input_value = True

__props__ = ("dtype",)
params_type = ParamsType(typecode=int32)

2 changes: 2 additions & 0 deletions pytensor/tensor/extra_ops.py
@@ -1561,6 +1561,8 @@ def broadcast_shape_iter(
class BroadcastTo(COp):
"""An `Op` for `numpy.broadcast_to`."""

_output_type_depends_on_input_value = True

__props__ = ()

view_map = {0: [0]}
2 changes: 2 additions & 0 deletions pytensor/tensor/random/op.py
@@ -91,6 +91,8 @@ class RandomVariable(Op):

"""

_output_type_depends_on_input_value = True

__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
default_output = 1

11 changes: 3 additions & 8 deletions pytensor/tensor/shape.py
@@ -146,14 +146,7 @@ def shape(x: Union[np.ndarray, Number, Variable]) -> Variable:
if not isinstance(x, Variable):
x = at.as_tensor_variable(x)

x_type = x.type

if isinstance(x_type, TensorType) and all(s is not None for s in x_type.shape):
res = at.as_tensor_variable(x_type.shape, ndim=1, dtype=np.int64)
else:
res = _shape(x)

return res
return _shape(x)


@_get_vector_length.register(Shape)
@@ -395,6 +388,7 @@ class SpecifyShape(COp):
view_map = {0: [0]}
__props__ = ()
_f16_ok = True
_output_type_depends_on_input_value = True

def make_node(self, x, *shape):
from pytensor.tensor.basic import get_underlying_scalar_constant_value
@@ -594,6 +588,7 @@ class Reshape(COp):

view_map = {0: [0]} # output 0 is potentially aliased to inputs [0]
_f16_ok = True
_output_type_depends_on_input_value = True

check_input = False
__props__ = ("ndim",)
5 changes: 1 addition & 4 deletions pytensor/tensor/var.py
@@ -12,7 +12,7 @@
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType
from pytensor.scalar import ComplexError, IntegerDivisionError
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
@@ -259,9 +259,6 @@ def transpose(self, *axes):

@property
def shape(self):
if not any(s is None for s in self.type.shape):
return as_tensor_variable(self.type.shape, ndim=1, dtype=np.int64)

return at.shape(self)

@property
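A small sketch of the user-visible consequence of the two changes above (`shape()` in `pytensor/tensor/shape.py` and `TensorVariable.shape` here): `x.shape` now always builds a symbolic `Shape` node, and a fully static shape is recovered by constant folding at compile time rather than eagerly at graph-construction time.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.shape import Shape

x = pt.vector("x", shape=(4,))

# The shape graph is symbolic even though the static shape is known...
assert isinstance(x.shape.owner.op, Shape)

# ...and the constant (4,) is still recovered once the graph is compiled.
fn = pytensor.function([x], x.shape)
assert np.array_equal(fn(np.zeros(4, dtype=x.dtype)), (4,))
```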
5 changes: 2 additions & 3 deletions tests/scan/test_rewriting.py
@@ -1477,13 +1477,12 @@ def test_while_scan_taps_and_map(self):
f(x0=0, seq=test_seq, n_steps=0)

# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
# If a MissingInputError is raised, it means the rewrite failed
[scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
_, _, ys_trace, len_zs = scan_node.inputs
debug_fn = pytensor.function(
[n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
[x0, n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
)
stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
stored_ys_steps, stored_zs_steps = debug_fn(x0=0, n_steps=200)
assert stored_ys_steps == 2
assert stored_zs_steps == 1

19 changes: 18 additions & 1 deletion tests/tensor/random/test_basic.py
@@ -14,6 +14,7 @@
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.basic import (
bernoulli,
@@ -57,7 +58,7 @@
weibull,
)
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.type import iscalar, scalar, tensor
from pytensor.tensor.type import iscalar, scalar, tensor, vector
from tests.unittest_tools import create_pytensor_param


@@ -1422,3 +1423,19 @@ def test_pickle():
a_unpkl = pickle.loads(a_pkl)

assert a_unpkl.owner.op._props() == sample_a.owner.op._props()


def test_rebuild():
x = vector(shape=(50,))
x_test = np.zeros((50,), dtype=config.floatX)
y = normal(size=x.shape)
assert y.type.shape == (50,)
assert y.shape.eval({x: x_test}) == (50,)
assert y.eval({x: x_test}).shape == (50,)

x_new = vector(shape=(100,))
x_new_test = np.zeros((100,), dtype=config.floatX)
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (100,)
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)
19 changes: 18 additions & 1 deletion tests/tensor/test_basic.py
@@ -16,6 +16,7 @@
from pytensor.gradient import grad, hessian
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import clone_replace
from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import autocast_float, autocast_float_as
@@ -818,6 +819,22 @@ def test_full(self):
res = pytensor.function([], full_at, mode=self.mode)()
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))

@pytest.mark.parametrize("func", (at.zeros, at.empty))
def test_rebuild(self, func):
x = vector(shape=(50,))
x_test = np.zeros((50,), dtype=config.floatX)
y = func(x.shape)
assert y.type.shape == (50,)
assert y.shape.eval({x: x_test}) == (50,)
assert y.eval({x: x_test}).shape == (50,)

x_new = vector(shape=(100,))
x_new_test = np.zeros((100,), dtype=config.floatX)
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (100,)
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)


def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
@@ -3506,7 +3523,7 @@ def test_vector(self):
def test_scalar(self):
x = scalar()
y = np.array(7, dtype=config.floatX)
assert y.size == function([], x.size)()
assert y.size == function([x], x.size)(y)

def test_shared(self):
# NB: we also test higher order tensors at the same time.
17 changes: 17 additions & 0 deletions tests/tensor/test_extra_ops.py
@@ -9,6 +9,7 @@
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, applys_between
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.raise_op import Assert
from pytensor.tensor.elemwise import DimShuffle
@@ -1399,6 +1400,22 @@ def test_inplace(self):

assert advincsub_node.op.inplace is False

def test_rebuild(self):
x = vector(shape=(50,))
x_test = np.zeros((50,), dtype=config.floatX)
i = 0
y = broadcast_to(i, x.shape)
assert y.type.shape == (50,)
assert y.shape.eval({x: x_test}) == (50,)
assert y.eval({x: x_test}).shape == (50,)

x_new = vector(shape=(100,))
x_new_test = np.zeros((100,), dtype=config.floatX)
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (100,)
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)


def test_broadcast_arrays():
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
5 changes: 1 addition & 4 deletions tests/tensor/test_nlinalg.py
@@ -7,7 +7,6 @@
import pytensor
from pytensor import function
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.tensor.math import _allclose
from pytensor.tensor.nlinalg import (
SVD,
@@ -274,9 +273,7 @@ def test_det_grad():

def test_det_shape():
x = matrix()
det_shape = det(x).shape
assert isinstance(det_shape, Constant)
assert tuple(det_shape.data) == ()
assert det(x).type.shape == ()


def test_slogdet():
Expand Down
35 changes: 34 additions & 1 deletion tests/tensor/test_shape.py
@@ -7,6 +7,7 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.replace import clone_replace
from pytensor.graph.type import Type
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar.basic import ScalarConstant
@@ -16,6 +17,7 @@
from pytensor.tensor.rewriting.shape import ShapeFeature
from pytensor.tensor.shape import (
Reshape,
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
@@ -336,6 +338,21 @@ def test_more_shapes(self):
Reshape,
)

def test_rebuild(self):
x = as_tensor_variable(50)
i = vector("i")
i_test = np.zeros((100,), dtype=config.floatX)
y = reshape(i, (100 // x, x))
assert y.type.shape == (2, 50)
assert tuple(y.shape.eval({i: i_test})) == (2, 50)
assert y.eval({i: i_test}).shape == (2, 50)

x_new = as_tensor_variable(25)
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (4, 25)
assert tuple(y_new.shape.eval({i: i_test})) == (4, 25)
assert y_new.eval({i: i_test}).shape == (4, 25)


def test_shape_i_hash():
assert isinstance(Shape_i(np.int64(1)).__hash__(), int)
@@ -397,7 +414,7 @@ def test_fixed_shapes(self):
shape = as_tensor_variable([2])
y = specify_shape(x, shape)
assert y.type.shape == (2,)
assert y.shape.equals(shape)
assert isinstance(y.shape.owner.op, Shape)

def test_fixed_partial_shapes(self):
x = TensorType("floatX", (None, None))("x")
@@ -523,6 +540,22 @@ def test_specify_shape_in_grad(self):
z_grad = grad(z.sum(), wrt=x)
assert isinstance(z_grad.owner.op, SpecifyShape)

def test_rebuild(self):
x = as_tensor_variable(50)
i = matrix("i")
i_test = np.zeros((4, 50), dtype=config.floatX)
y = specify_shape(i, (None, x))
assert y.type.shape == (None, 50)
assert tuple(y.shape.eval({i: i_test})) == (4, 50)
assert y.eval({i: i_test}).shape == (4, 50)

x_new = as_tensor_variable(100)
i_test = np.zeros((4, 100), dtype=config.floatX)
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
assert y_new.type.shape == (None, 100)
assert tuple(y_new.shape.eval({i: i_test})) == (4, 100)
assert y_new.eval({i: i_test}).shape == (4, 100)


class TestSpecifyBroadcastable:
def test_basic(self):
12 changes: 10 additions & 2 deletions tests/tensor/test_var.py
@@ -6,12 +6,14 @@

import pytensor
import tests.unittest_tools as utt
from pytensor.compile import DeepCopyOp
from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.type import (
TensorType,
@@ -245,8 +247,14 @@ def test__getitem__newaxis(x, indices, new_order):

def test_fixed_shape_variable_basic():
x = TensorVariable(TensorType("int64", shape=(4,)), None)
assert isinstance(x.shape, Constant)
assert np.array_equal(x.shape.data, (4,))
assert x.type.shape == (4,)
assert isinstance(x.shape.owner.op, Shape)

shape_fn = pytensor.function([x], x.shape)
opt_shape = shape_fn.maker.fgraph.outputs[0]
assert isinstance(opt_shape.owner.op, DeepCopyOp)
assert isinstance(opt_shape.owner.inputs[0], Constant)
assert np.array_equal(opt_shape.owner.inputs[0].data, (4,))

x = TensorConstant(
TensorType("int64", shape=(None, None)), np.array([[1, 2], [2, 3]])