Skip to content

Allow vectorization of graphs with IfElse #1510

Open
@jessegrabowski

Description

@jessegrabowski

Description

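As the title suggests, graphs containing an `IfElse` fail at runtime if you try to vectorize them. The error raised reminds me of #1425. From the traceback below, `vectorize_graph` wraps the `IfElse` in a `Blockwise`. The failure happens when `Blockwise.perform` invokes the core `IfElse` thunk, whose check `compute_map[cond][0]` runs into a `None` value.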

import pytensor.tensor as pt
import pytensor
import numpy as np

x = pt.dvector('x')
a = pt.dscalar('a')
z = pytensor.ifelse(pt.ge(a, 0), x, x.zeros_like())

x_batched = pt.dmatrix('x')
z_batched = pytensor.graph.replace.vectorize_graph(z, {x:x_batched})
fn = pytensor.function([a, x_batched], z_batched)

fn(1.0, np.arange(9).reshape((3 ,3)))

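Raises the following. Note that the traceback below was captured with `np.arange(10)[:, None]` as the batched input, rather than the 3×3 array used above: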

Full Traceback (from a Jupyter notebook cell)
import pytensor.tensor as pt
import pytensor
import numpy as np

x = pt.dvector('x')
a = pt.dscalar('a')
z = pytensor.ifelse(pt.ge(a, 0), x, x.zeros_like())

x_batched = pt.dmatrix('x')
z_batched = pytensor.graph.replace.vectorize_graph(z, {x:x_batched})
fn = pytensor.function([a, x_batched], z_batched)

fn(1.0, np.arange(10)[:, None])

---------------------------------------------------------------------------


TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/compile/function/types.py:1039, in Function.__call__(self, output_subset, *args, **kwargs)
   1038 try:
-> 1039     outputs = vm() if output_subset is None else vm(output_subset=output_subset)
   1040 except Exception:

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/graph/op.py:544, in Op.make_py_thunk.<locals>.rval(p, i, o, n, cm)
    536 @is_thunk_type
    537 def rval(
    538     p=p,
   (...)    542     cm=node_compute_map,
    543 ):
--> 544     r = p(n, [x[0] for x in i], o)
    545     for entry in cm:

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/tensor/blockwise.py:489, in Blockwise.perform(self, node, inputs, output_storage)
    488     gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
--> 489 for out_storage, result in zip(output_storage, gufunc(*inputs)):
    490     out_storage[0] = result

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/tensor/blockwise.py:96, in _vectorize_node_perform.<locals>.vectorized_perform(batch_bcast_patterns, batch_ndim, single_in, core_thunk, core_input_storage, core_output_storage, core_storage, *args)
     95     core_input[0] = np.asarray(arg[index0])
---> 96 core_thunk()
     97 outputs = tuple(
     98     empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
     99     for core_output in core_output_storage
    100 )

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/ifelse.py:295, in IfElse.make_thunk.<locals>.thunk()
    294 def thunk():
--> 295     if not compute_map[cond][0]:
    296         return [0]

TypeError: 'NoneType' object is not subscriptable

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[2], line 13
     10 z_batched = pytensor.graph.replace.vectorize_graph(z, {x:x_batched})
     11 fn = pytensor.function([a, x_batched], z_batched)
---> 13 fn(1.0, np.arange(10)[:, None])

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/compile/function/types.py:1049, in Function.__call__(self, output_subset, *args, **kwargs)
   1047     if hasattr(self.vm, "thunks"):
   1048         thunk = self.vm.thunks[self.vm.position_of_error]
-> 1049     raise_with_op(
   1050         self.maker.fgraph,
   1051         node=self.vm.nodes[self.vm.position_of_error],
   1052         thunk=thunk,
   1053         storage_map=getattr(self.vm, "storage_map", None),
   1054     )
   1055 else:
   1056     # old-style linkers raise their own exceptions
   1057     raise

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/link/utils.py:526, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    521     warnings.warn(
    522         f"{exc_type} error does not allow us to add an extra error message"
    523     )
    524     # Some exception need extra parameter in inputs. So forget the
    525     # extra long error message in that case.
--> 526 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/compile/function/types.py:1039, in Function.__call__(self, output_subset, *args, **kwargs)
   1037     t0_fn = time.perf_counter()
   1038 try:
-> 1039     outputs = vm() if output_subset is None else vm(output_subset=output_subset)
   1040 except Exception:
   1041     self._restore_defaults()

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/graph/op.py:544, in Op.make_py_thunk.<locals>.rval(p, i, o, n, cm)
    536 @is_thunk_type
    537 def rval(
    538     p=p,
   (...)    542     cm=node_compute_map,
    543 ):
--> 544     r = p(n, [x[0] for x in i], o)
    545     for entry in cm:
    546         entry[0] = True

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/tensor/blockwise.py:489, in Blockwise.perform(self, node, inputs, output_storage)
    487 except AttributeError:
    488     gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
--> 489 for out_storage, result in zip(output_storage, gufunc(*inputs)):
    490     out_storage[0] = result

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/tensor/blockwise.py:96, in _vectorize_node_perform.<locals>.vectorized_perform(batch_bcast_patterns, batch_ndim, single_in, core_thunk, core_input_storage, core_output_storage, core_storage, *args)
     94 for core_input, arg in zip(core_input_storage, args):
     95     core_input[0] = np.asarray(arg[index0])
---> 96 core_thunk()
     97 outputs = tuple(
     98     empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
     99     for core_output in core_output_storage
    100 )
    101 for output, core_output in zip(outputs, core_output_storage):

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/ifelse.py:295, in IfElse.make_thunk.<locals>.thunk()
    294 def thunk():
--> 295     if not compute_map[cond][0]:
    296         return [0]
    297     else:

TypeError: 'NoneType' object is not subscriptable
Apply node that caused the error: Blockwise{if{}, (),(i10),(i20)->(o00)}(Ge.0, x, Alloc.0)
Toposort index: 4
Inputs types: [TensorType(bool, shape=(1,)), TensorType(float64, shape=(None, None)), TensorType(float64, shape=(1, None))]
Inputs shapes: [(1,), (10, 1), (1, 1)]
Inputs strides: [(1,), (8, 8), (8, 8)]
Inputs values: [array([ True]), 'not shown', array([[0.]])]
Outputs clients: [[output[0](Blockwise{if{}, (),(i10),(i20)->(o00)}.0)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3367, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3612, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3672, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_7777/3079844208.py", line 10, in <module>
    z_batched = pytensor.graph.replace.vectorize_graph(z, {x:x_batched})
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/graph/replace.py", line 301, in vectorize_graph
    vect_node = vectorize_node(node, *vect_inputs)
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pytensor/graph/replace.py", line 217, in vectorize_node
    return _vectorize_node(op, node, *batched_inputs)
  File "/Users/jessegrabowski/mambaforge/envs/readystate-bonds/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions