Skip to content

ADVI errors in numba mode for StudentT likelihood when total_size is set #7778

@jessegrabowski

Description

@jessegrabowski

Description

# Minimal reproduction: ADVI via pm.fit compiled in NUMBA mode, with a
# minibatched StudentT likelihood that has total_size set.
import pymc as pm
import numpy as np

# No seed is passed, so the synthetic data differs from run to run.
rng = np.random.default_rng()

with pm.Model() as m:
    # Synthetic regression data: a (1000, 5) predictor matrix and a (1000,) target vector.
    data = pm.Data('data', rng.normal(size=(1000, 5)))
    obs = pm.Data('obs', rng.normal(size=(1000,)))
    
    # Minibatched views of both containers, requested with batch_size=128.
    data_batch, obs_batch = pm.Minibatch(data, obs, batch_size=128)
    
    # Linear predictor: one coefficient per predictor column.
    beta = pm.Normal('beta', size=(5,))
    mu = data_batch @ beta
    sigma = pm.Exponential('sigma', 1)
    
    # Observed on the minibatch; total_size=1000 matches the row count of the full data.
    # The report states the error does not occur when total_size is None.
    y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)
    
    # This call raises NotImplementedError ("Core implementation of t_rv ... not
    # implemented") while the step function is compiled with the NUMBA mode.
    # NOTE(review): pm.fit presumably returns an approximation object rather than
    # InferenceData, so the name `idata` may be misleading; confirm.
    idata = pm.fit(n=1_000_000, 
                   compile_kwargs={'mode':'NUMBA'})
Full Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 18
     14 sigma = pm.Exponential('sigma', 1)
     16 y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)
---> 18 idata = pm.fit(n=1_000_000, 
     19                compile_kwargs={'mode':'NUMBA'})

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:775, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    773 else:
    774     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 775 return inference.fit(n, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:158, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs)
    156     callbacks = []
    157 score = self._maybe_score(score)
--> 158 step_func = self.objective.step_function(score=score, **kwargs)
    160 if score:
    161     state = self._iterate_with_loss(
    162         0, n, step_func, progressbar, progressbar_theme, callbacks
    163     )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/opvi.py:405, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, compile_kwargs, fn_kwargs)
    403 seed = self.approx.rng.randint(2**30, dtype=np.int64)
    404 if score:
--> 405     step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
    406 else:
    407     step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/pytensorf.py:947, in compile(inputs, outputs, random_seed, mode, **kwargs)
    945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    946 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 947 pytensor_function = pytensor.function(
    948     inputs,
    949     outputs,
    950     updates={**rng_updates, **kwargs.pop("updates", {})},
    951     mode=mode,
    952     **kwargs,
    953 )
    954 return pytensor_function

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
    321     fn = orig_function(
    322         inputs,
    323         outputs,
   (...)    327         trust_input=trust_input,
    328     )
    329 else:
    330     # note: pfunc will also call orig_function -- orig_function is
    331     #      a choke point that all compilation must pass through
--> 332     fn = pfunc(
    333         params=inputs,
    334         outputs=outputs,
    335         mode=mode,
    336         updates=updates,
    337         givens=givens,
    338         no_default_updates=no_default_updates,
    339         accept_inplace=accept_inplace,
    340         name=name,
    341         rebuild_strict=rebuild_strict,
    342         allow_input_downcast=allow_input_downcast,
    343         on_unused_input=on_unused_input,
    344         profile=profile,
    345         output_keys=output_keys,
    346         trust_input=trust_input,
    347     )
    348 return fn

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
    452     profile = ProfileStats(message=profile)
    454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    455     params,
    456     outputs,
   (...)    463     fgraph=fgraph,
    464 )
--> 466 return orig_function(
    467     inputs,
    468     cloned_outputs,
    469     mode,
    470     accept_inplace=accept_inplace,
    471     name=name,
    472     profile=profile,
    473     on_unused_input=on_unused_input,
    474     output_keys=output_keys,
    475     fgraph=fgraph,
    476     trust_input=trust_input,
    477 )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
   1820     m = Maker(
   1821         inputs,
   1822         outputs,
   (...)   1830         trust_input=trust_input,
   1831     )
   1832     with config.change_flags(compute_test_value="off"):
-> 1833         fn = m.create(defaults)
   1834 finally:
   1835     if profile and fn:

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
   1714 start_import_time = pytensor.link.c.cmodule.import_time
   1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717     _fn, _i, _o = self.linker.make_thunk(
   1718         input_storage=input_storage_lists, storage_map=storage_map
   1719     )
   1721 end_linker = time.perf_counter()
   1723 linker_time = end_linker - start_linker

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
      7 def fgraph_convert(self, fgraph, **kwargs):
      8     from pytensor.link.numba.dispatch import numba_funcify
---> 10     return numba_funcify(fgraph, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:380, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    373 @numba_funcify.register(FunctionGraph)
    374 def numba_funcify_FunctionGraph(
    375     fgraph,
   (...)    378     **kwargs,
    379 ):
--> 380     return fgraph_to_python(
    381         fgraph,
    382         numba_funcify,
    383         type_conversion_fn=numba_typify,
    384         fgraph_name=fgraph_name,
    385         **kwargs,
    386     )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:401, in numba_funcify_RandomVariable(op, node, **kwargs)
    398 core_shape_len = get_vector_length(core_shape)
    399 inplace = rv_op.inplace
--> 401 core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
    402 nin = 1 + len(dist_params)  # rng + params
    403 core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:47, in numba_core_rv_funcify(op, node)
     44 @singledispatch
     45 def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
     46     """Return the core function for a random variable operation."""
---> 47     raise NotImplementedError(f"Core implementation of {op} not implemented.")

NotImplementedError: Core implementation of t_rv{"(),(),()->()"} not implemented.

Interestingly, it works fine if you set total_size=None in the StudentT likelihood.

Metadata

Metadata

Assignees

No one assigned

    Labels

    VI (Variational Inference), bug, numba

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions