
numba_funcify_Elemwise can end up with parent_node in its kwargs when vectorizing OpFromGraph #1507

Open
@jessegrabowski

Description

Compiling a graph that contains an OpFromGraph and was vectorized with vectorize_graph in NUMBA mode raises a TypeError:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.replace import vectorize_graph

X = pt.dmatrix("X", shape=(None, None))
X_batched = pt.tensor("X", shape=(None, None, None))

z = X + 1

results = OpFromGraph(
    inputs=[X],
    outputs=[z],
)(X)

# Vectorizing the OpFromGraph wraps it in a Blockwise; compiling in NUMBA mode fails
z_vec = vectorize_graph(results, {X: X_batched})
fn = pytensor.function(
    [X_batched],
    [z_vec],
    mode='NUMBA',
)
Full traceback
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[36], line 12
      6 results = OpFromGraph(
      7     inputs=[X],
      8     outputs=[z],
      9 )(X)
     11 z_vec = vectorize_graph(results, {X: X_batched})
---> 12 fn = pytensor.function(
     13     [X_batched],
     14     [z_vec],
     15     mode='NUMBA',
     16 )

File ~/Documents/Python/pytensor/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
    321     fn = orig_function(
    322         inputs,
    323         outputs,
   (...)    327         trust_input=trust_input,
    328     )
    329 else:
    330     # note: pfunc will also call orig_function -- orig_function is
    331     #      a choke point that all compilation must pass through
--> 332     fn = pfunc(
    333         params=inputs,
    334         outputs=outputs,
    335         mode=mode,
    336         updates=updates,
    337         givens=givens,
    338         no_default_updates=no_default_updates,
    339         accept_inplace=accept_inplace,
    340         name=name,
    341         rebuild_strict=rebuild_strict,
    342         allow_input_downcast=allow_input_downcast,
    343         on_unused_input=on_unused_input,
    344         profile=profile,
    345         output_keys=output_keys,
    346         trust_input=trust_input,
    347     )
    348 return fn

File ~/Documents/Python/pytensor/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
    452     profile = ProfileStats(message=profile)
    454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    455     params,
    456     outputs,
   (...)    463     fgraph=fgraph,
    464 )
--> 466 return orig_function(
    467     inputs,
    468     cloned_outputs,
    469     mode,
    470     accept_inplace=accept_inplace,
    471     name=name,
    472     profile=profile,
    473     on_unused_input=on_unused_input,
    474     output_keys=output_keys,
    475     fgraph=fgraph,
    476     trust_input=trust_input,
    477 )

File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1835, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
   1822     m = Maker(
   1823         inputs,
   1824         outputs,
   (...)   1832         trust_input=trust_input,
   1833     )
   1834     with config.change_flags(compute_test_value="off"):
-> 1835         fn = m.create(defaults)
   1836 finally:
   1837     if profile and fn:

File ~/Documents/Python/pytensor/pytensor/compile/function/types.py:1719, in FunctionMaker.create(self, input_storage, storage_map)
   1716 start_import_time = pytensor.link.c.cmodule.import_time
   1718 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1719     _fn, _i, _o = self.linker.make_thunk(
   1720         input_storage=input_storage_lists, storage_map=storage_map
   1721     )
   1723 end_linker = time.perf_counter()
   1725 linker_time = end_linker - start_linker

File ~/Documents/Python/pytensor/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/Documents/Python/pytensor/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/Documents/Python/pytensor/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/Documents/Python/pytensor/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
      7 def fgraph_convert(self, fgraph, **kwargs):
      8     from pytensor.link.numba.dispatch import numba_funcify
---> 10     return numba_funcify(fgraph, **kwargs)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/basic.py:379, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    372 @numba_funcify.register(FunctionGraph)
    373 def numba_funcify_FunctionGraph(
    374     fgraph,
   (...)    377     **kwargs,
    378 ):
--> 379     return fgraph_to_python(
    380         fgraph,
    381         numba_funcify,
    382         type_conversion_fn=numba_typify,
    383         fgraph_name=fgraph_name,
    384         **kwargs,
    385     )

File ~/Documents/Python/pytensor/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/blockwise.py:31, in numba_funcify_Blockwise(op, node, **kwargs)
     26 core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])
     28 core_node = blockwise_op._create_dummy_core_node(
     29     cast(tuple[TensorVariable], blockwise_node.inputs)
     30 )
---> 31 core_op_fn = numba_funcify(
     32     core_op,
     33     node=core_node,
     34     parent_node=node,
     35     **kwargs,
     36 )
     37 core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
     39 batch_ndim = blockwise_op.batch_ndim(node)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/basic.py:355, in numba_funcify_OpFromGraph(op, node, **kwargs)
    349 add_supervisor_to_fgraph(
    350     fgraph=fgraph,
    351     input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
    352     accept_inplace=True,
    353 )
    354 NUMBA.optimizer(fgraph)
--> 355 fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
    357 if len(op.fgraph.outputs) == 1:
    359     @numba_njit
    360     def opfromgraph(*inputs):

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/basic.py:379, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    372 @numba_funcify.register(FunctionGraph)
    373 def numba_funcify_FunctionGraph(
    374     fgraph,
   (...)    377     **kwargs,
    378 ):
--> 379     return fgraph_to_python(
    380         fgraph,
    381         numba_funcify,
    382         type_conversion_fn=numba_typify,
    383         fgraph_name=fgraph_name,
    384         **kwargs,
    385     )

File ~/Documents/Python/pytensor/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/pytensor-dev/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/Documents/Python/pytensor/pytensor/link/numba/dispatch/elemwise.py:270, in numba_funcify_Elemwise(op, node, **kwargs)
    267 scalar_inputs = [get_scalar_type(dtype=input.dtype)() for input in node.inputs]
    268 scalar_node = op.scalar_op.make_node(*scalar_inputs)
--> 270 scalar_op_fn = numba_funcify(
    271     op.scalar_op,
    272     node=scalar_node,
    273     parent_node=node,
    274     **kwargs,
    275 )
    277 nin = len(node.inputs)
    278 nout = len(node.outputs)

TypeError: pytensor.link.numba.dispatch.basic.numba_funcify() got multiple values for keyword argument 'parent_node'
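
From the traceback, the duplicate keyword appears to originate in numba_funcify_Blockwise, which passes parent_node=node when funcifying its core op. Here the core op is the OpFromGraph, and numba_funcify_OpFromGraph forwards **kwargs (still carrying parent_node) into numba_funcify(op.fgraph, **kwargs) and on through fgraph_to_python, so when numba_funcify_Elemwise injects its own parent_node=node the dispatcher receives the argument twice. The snippet below is only a minimal sketch of that calling pattern, using hypothetical stub functions rather than the real PyTensor dispatchers:

def numba_funcify_stub(op, node=None, parent_node=None, **kwargs):
    # Stand-in for the singledispatch entry point numba_funcify.
    return f"compiled {op} (parent_node={parent_node})"


def elemwise_funcify_stub(op, node=None, **kwargs):
    # Mirrors numba_funcify_Elemwise: it always injects parent_node=node ...
    return numba_funcify_stub(op, node=node, parent_node=node, **kwargs)


# ... but the kwargs forwarded down from numba_funcify_Blockwise through
# numba_funcify_OpFromGraph and fgraph_to_python already contain parent_node,
# so the inner dispatch receives the keyword twice:
elemwise_funcify_stub("scalar_add", node="inner_elemwise_node",
                      parent_node="outer_blockwise_node")
# TypeError: numba_funcify_stub() got multiple values for keyword argument 'parent_node'

If that reading is right, one possible direction (not verified against the codebase) would be for numba_funcify_OpFromGraph to drop parent_node from the kwargs it forwards to the inner FunctionGraph, or for fgraph_to_python not to propagate it to per-node conversions.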
