
JAX backend fails for latent scan variables #6718

Open
jessegrabowski opened this issue May 15, 2023 · 10 comments

@jessegrabowski
Member

jessegrabowski commented May 15, 2023

Describe the issue:

Not sure if this belongs here or in the PyTensor repo. I'm putting it here because the minimal example I can come up with uses PyMC. If you create a Scan variable compiled with the JAX mode, register it without observations, and then use it for further computation, the graph fails to compile.

Reproducible code example:

import numpy as np
import pymc as pm
import pytensor
from pytensor.compile.mode import get_mode
from pymc.pytensorf import collect_default_updates

true_sigma = 0.1
true_eta = 0.25

# Gaussian random walk (GRW) with observation noise:
test_mu = np.random.normal(scale=true_sigma, size=100).cumsum()
test_obs = np.random.normal(loc=test_mu, scale=true_eta)

with pm.Model() as model:
    x0 = pm.Normal('x0')
    sigma = pm.HalfNormal('sigma')
    eta = pm.HalfNormal('eta')

    def step(*args):
        # args = (previous state, non-sequence sigma)
        last_x, sigma = args
        x = pm.Normal.dist(mu=last_x, sigma=sigma)
        # Return the RNG updates so the generator advances on every step
        return x, collect_default_updates(args, [x])

    # The inner Scan function is compiled with the JAX mode
    traj, updates = pytensor.scan(step,
                                  outputs_info=[x0],
                                  non_sequences=[sigma],
                                  n_steps=100,
                                  mode=get_mode('JAX'))

    # Latent trajectory: registered without observations, initval drawn from the prior
    model.register_rv(traj, name='traj', initval='prior')
    obs = pm.Normal('obs', mu=traj, sigma=eta, observed=test_obs)
    idata = pm.sample(nuts_sampler='numpyro')

Error message:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:563, in shaped_abstractify(x)
    562 try:
--> 563   return _shaped_abstractify_handlers[type(x)](x)
    564 except KeyError:

KeyError: <class 'numpy.random._generator.Generator'>

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:216, in streamline.<locals>.streamline_nice_errors_f()
    215     for thunk, node in zip(thunks, order):
--> 216         thunk()
    217 except Exception:

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 6 frame]

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:554, in _shaped_abstractify_slow(x)
    553 else:
--> 554   raise TypeError(
    555       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    556       "does not have a dtype attribute")
    557 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    558                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/vm.py:414, in Loop.__call__(self)
    411 for thunk, node, old_storage in zip_longest(
    412     self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
    413 ):
--> 414     thunk()
    415     for old_s in old_storage:

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1657, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
   1654 def rval(
   1655     p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
   1656 ):
-> 1657     r = p(n, [x[0] for x in i], o)
   1658     for o in node.outputs:

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1918, in Scan.perform(self, node, inputs, output_storage, params)
   1917 try:
-> 1918     vm()
   1919 except Exception:

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:218, in streamline.<locals>.streamline_nice_errors_f()
    217 except Exception:
--> 218     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:216, in streamline.<locals>.streamline_nice_errors_f()
    215     for thunk, node in zip(thunks, order):
--> 216         thunk()
    217 except Exception:

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 6 frame]

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:554, in _shaped_abstractify_slow(x)
    553 else:
--> 554   raise TypeError(
    555       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    556       "does not have a dtype attribute")
    557 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    558                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(*1-<RandomGeneratorType>, TensorConstant{[]}, TensorConstant{11}, *0-<TensorType(float64, ())>, *2-<TensorType(float64, ())>)
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: [(), 'No shapes', ()]
Inputs strides: [(), 'No strides', ()]
Inputs values: [array(0.), Generator(PCG64) at 0x17E245FC0, array(1.)]
Outputs clients: [['output'], ['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 22, in <module>
    traj, updates = pytensor.scan(step,
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/basic.py", line 852, in scan
    raw_inner_outputs = fn(*args)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 19, in step
    x = pm.Normal.dist(mu=last_x, sigma=sigma)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/continuous.py", line 520, in dist
    return super().dist([mu, sigma], **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 389, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[2], line 30
     28 model.register_rv(traj, name='traj', initval='prior')
     29 obs = pm.Normal('obs', mu=traj, sigma=eta, observed=test_obs)
---> 30 idata = pm.sample(nuts_sampler='numpyro')

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/sampling/mcmc.py:564, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    561         auto_nuts_init = False
    563 initial_points = None
--> 564 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    566 if nuts_sampler != "pymc":
    567     if not isinstance(step, NUTS):

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/sampling/mcmc.py:203, in assign_step_methods(model, step, methods, step_kwargs)
    195         selected = max(
    196             methods,
    197             key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
    198                 var, has_gradient
    199             ),
    200         )
    201         selected_steps[selected].append(var)
--> 203 return instantiate_steppers(model, steps, selected_steps, step_kwargs)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/sampling/mcmc.py:116, in instantiate_steppers(model, steps, selected_steps, step_kwargs)
    114         args = step_kwargs.get(step_class.name, {})
    115         used_keys.add(step_class.name)
--> 116         step = step_class(vars=vars, model=model, **args)
    117         steps.append(step)
    119 unused_args = set(step_kwargs).difference(used_keys)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/step_methods/hmc/nuts.py:180, in NUTS.__init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
    122 def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
    123     r"""Set up the No-U-Turn sampler.
    124 
    125     Parameters
   (...)
    178     `pm.sample` to the desired number of tuning steps.
    179     """
--> 180     super().__init__(vars, **kwargs)
    182     self.max_treedepth = max_treedepth
    183     self.early_max_treedepth = early_max_treedepth

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/step_methods/hmc/base_hmc.py:109, in BaseHMC.__init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **pytensor_kwargs)
    107 else:
    108     vars = get_value_vars_from_user_vars(vars, self._model)
--> 109 super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **pytensor_kwargs)
    111 self.adapt_step_size = adapt_step_size
    112 self.Emax = Emax

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/step_methods/arraystep.py:164, in GradientSharedStep.__init__(self, vars, model, blocked, dtype, logp_dlogp_func, **pytensor_kwargs)
    161 model = modelcontext(model)
    163 if logp_dlogp_func is None:
--> 164     func = model.logp_dlogp_function(vars, dtype=dtype, **pytensor_kwargs)
    165 else:
    166     func = logp_dlogp_func

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/model.py:649, in Model.logp_dlogp_function(self, grad_vars, tempered, **kwargs)
    646     costs = [self.logp()]
    648 input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
--> 649 ip = self.initial_point(0)
    650 extra_vars_and_values = {
    651     var: ip[var.name]
    652     for var in self.value_vars
    653     if var in input_vars and var not in grad_vars
    654 }
    655 return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/model.py:1133, in Model.initial_point(self, random_seed)
   1120 """Computes the initial point of the model.
   1121 
   1122 Parameters
   (...)
   1130     Maps names of transformed variables to numeric initial values in the transformed space.
   1131 """
   1132 fn = make_initial_point_fn(model=self, return_transformed=True)
-> 1133 return Point(fn(random_seed), model=self)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/initial_point.py:169, in make_initial_point_fn.<locals>.make_seeded_function.<locals>.inner(seed, *args, **kwargs)
    166 @functools.wraps(func)
    167 def inner(seed, *args, **kwargs):
    168     reseed_rngs(rngs, seed)
--> 169     values = func(*args, **kwargs)
    170     return dict(zip(varnames, values))

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/vm.py:418, in Loop.__call__(self)
    416                 old_s[0] = None
    417     except Exception:
--> 418         raise_with_op(self.fgraph, node, thunk)
    420 return self.perform_updates()

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/vm.py:414, in Loop.__call__(self)
    410 try:
    411     for thunk, node, old_storage in zip_longest(
    412         self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
    413     ):
--> 414         thunk()
    415         for old_s in old_storage:
    416             old_s[0] = None

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1657, in Scan.make_thunk.<locals>.rval(p, i, o, n, allow_gc)
   1654 def rval(
   1655     p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
   1656 ):
-> 1657     r = p(n, [x[0] for x in i], o)
   1658     for o in node.outputs:
   1659         compute_map[o][0] = True

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/op.py:1918, in Scan.perform(self, node, inputs, output_storage, params)
   1915 t0_fn = time.perf_counter()
   1917 try:
-> 1918     vm()
   1919 except Exception:
   1920     if hasattr(vm, "position_of_error"):
   1921         # this is a new vm-provided function or c linker
   1922         # they need this because the exception manipulation
   1923         # done by raise_with_op is not implemented in C.

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:218, in streamline.<locals>.streamline_nice_errors_f()
    216         thunk()
    217 except Exception:
--> 218     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:535, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    530     warnings.warn(
    531         f"{exc_type} error does not allow us to add an extra error message"
    532     )
    533     # Some exception need extra parameter in inputs. So forget the
    534     # extra long error message in that case.
--> 535 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/utils.py:216, in streamline.<locals>.streamline_nice_errors_f()
    214 try:
    215     for thunk, node in zip(thunks, order):
--> 216         thunk()
    217 except Exception:
    218     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/link/basic.py:669, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    663 def thunk(
    664     fgraph=self.fgraph,
    665     fgraph_jit=fgraph_jit,
    666     thunk_inputs=thunk_inputs,
    667     thunk_outputs=thunk_outputs,
    668 ):
--> 669     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    671     for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    672         compute_map[o_var][0] = True

    [... skipping hidden 6 frame]

File ~/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/jax/_src/api_util.py:554, in _shaped_abstractify_slow(x)
    552   dtype = dtypes.canonicalize_dtype(x.dtype, allow_opaque_dtype=True)
    553 else:
--> 554   raise TypeError(
    555       f"Cannot interpret value of type {type(x)} as an abstract array; it "
    556       "does not have a dtype attribute")
    557 return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type,
    558                         named_shape=named_shape)

TypeError: Cannot interpret value of type <class 'numpy.random._generator.Generator'> as an abstract array; it does not have a dtype attribute
Apply node that caused the error: normal_rv{0, (0, 0), floatX, False}(*1-<RandomGeneratorType>, TensorConstant{[]}, TensorConstant{11}, *0-<TensorType(float64, ())>, *2-<TensorType(float64, ())>)
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, (0,)), TensorType(int64, ()), TensorType(float64, ()), TensorType(float64, ())]
Inputs shapes: [(), 'No shapes', ()]
Inputs strides: [(), 'No strides', ()]
Inputs values: [array(0.), Generator(PCG64) at 0x17E245FC0, array(1.)]
Outputs clients: [['output'], ['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 22, in <module>
    traj, updates = pytensor.scan(step,
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pytensor/scan/basic.py", line 852, in scan
    raw_inner_outputs = fn(*args)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 19, in step
    x = pm.Normal.dist(mu=last_x, sigma=sigma)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/continuous.py", line 520, in dist
    return super().dist([mu, sigma], **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 389, in dist
    rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
Apply node that caused the error: for{cpu,scan_fn}(TensorConstant{100}, IncSubtensor{Set;:int64:}.0, RandomGeneratorSharedVariable(<Generator(PCG64) at 0x17E245FC0>), TensorConstant{1.0})
Toposort index: 5
Inputs types: [TensorType(int8, ()), TensorType(float64, (101,)), RandomGeneratorType, TensorType(float64, ())]
Inputs shapes: [(), (101,), 'No shapes', ()]
Inputs strides: [(), (8,), 'No strides', ()]
Inputs values: [array(100, dtype=int8), 'not shown', Generator(PCG64) at 0x17E245FC0, array(1.)]
Outputs clients: [[Subtensor{int64::}(for{cpu,scan_fn}.0, ScalarConstant{1})], []]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 540, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
    result = self._run_cell(
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
    result = runner(coro)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/ukraine-sentiment/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_32164/1022036140.py", line 22, in <module>
    traj, updates = pytensor.scan(step,

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

PyMC version information:

PyMC: 5.3.0
PyTensor: 2.11.1

@ricardoV94
Member

Possibly related to #6351

The issue is that there is not a 1-to-1 mapping between the Scan RV and the Scan value variable (I think this is because the output of Scan is actually a slice).
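As a rough illustration of the slice point, the sketch below emulates in plain NumPy what the traceback suggests the compiled Scan does: it fills a buffer of length n_steps + 1 (the `TensorType(float64, (101,))` input) whose first entry is `x0`, and the user-facing `traj` is the slice `buffer[1:]` (the `Subtensor{int64::}(..., 1)` client). The function name `emulate_grw_scan` is hypothetical, and this is not PyTensor's implementation.

```python
import numpy as np


def emulate_grw_scan(x0: float, sigma: float, n_steps: int,
                     rng: np.random.Generator) -> np.ndarray:
    """Hypothetical NumPy emulation of the GRW Scan from the example above."""
    # The Scan buffer holds the initial state plus one slot per step
    buffer = np.empty(n_steps + 1)
    buffer[0] = x0
    for t in range(n_steps):
        # Each step draws x_t ~ Normal(x_{t-1}, sigma), like pm.Normal.dist in `step`
        buffer[t + 1] = rng.normal(loc=buffer[t], scale=sigma)
    # The user-facing trajectory is a slice that drops the initial state
    return buffer[1:]


rng = np.random.default_rng(0)
traj = emulate_grw_scan(x0=0.0, sigma=1.0, n_steps=100, rng=rng)
print(traj.shape)  # (100,)
```

Because the registered RV is a slice of the Scan output rather than the Scan output itself, it is plausible that this is what breaks the simple RV-to-value correspondence.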

@jessegrabowski
Member Author

jessegrabowski commented May 15, 2023

It works fine with the default backend. I was hoping it had something to do with how the RNG is passed around the graph during JAX compilation, since that's what the error complains about.

@ricardoV94
Member

Ah okay, that sounds different then

@ricardoV94
Member

Regardless of this issue, we should definitely clean up the Scan mode handling. I think it should use whatever mode the outer function is using.

@jessegrabowski
Member Author

How complex of a fix would that be?

@ricardoV94
Member

By the way, this seems to be triggered by the initial_point function, probably because of the way it seeds things. You can trigger the failure with just model.initial_point().
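The traceback shows make_initial_point_fn calling `reseed_rngs(rngs, seed)` before evaluating the function, so the RNGs that reach the Scan are NumPy `Generator` objects. A minimal NumPy sketch of that kind of seeding, assuming one independent `Generator` per RNG spawned from a single `SeedSequence`; the helper `make_seeded_generators` is hypothetical and is not PyMC's code.

```python
import numpy as np


def make_seeded_generators(seed: int, n_rngs: int) -> list:
    """Hypothetical helper: one independent Generator per RNG, all from one seed."""
    # Spawn statistically independent child seeds from a single root seed
    children = np.random.SeedSequence(seed).spawn(n_rngs)
    # Each child seeds its own PCG64 bit generator, the type shown in the traceback
    return [np.random.Generator(np.random.PCG64(child)) for child in children]


first = [g.random() for g in make_seeded_generators(seed=0, n_rngs=3)]
second = [g.random() for g in make_seeded_generators(seed=0, n_rngs=3)]
print(first == second)  # True: same seed, same draws
```

Whatever the exact seeding scheme, the result is a NumPy `Generator`, which is the value type JAX rejects in the error above.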

@ricardoV94
Member

> How complex of a fix would that be?

I don't quite know, but it's worth a look. One option is to grab the user-provided mode and exclude rewrites that are incompatible with JAX (since we then know we are compiling to JAX). Alternatively, we could add an optional kwarg to the dispatch function that carries the mode provided by the JITLinker.
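A toy sketch of the two options, in plain Python. Every name here (`ToyMode`, `JAX_INCOMPATIBLE_TAGS`, `option1_adapt_user_mode`, `option2_inner_mode`) and every rewrite tag is hypothetical; none of it is PyTensor's actual API. It only shows the decision logic: option 1 keeps the user's rewrites minus the ones JAX can't handle, and option 2 lets the outer compilation's mode override the mode baked into the Scan.

```python
from dataclasses import dataclass, replace
from typing import FrozenSet, Optional

# Hypothetical rewrite tags the JAX backend cannot handle (illustrative only)
JAX_INCOMPATIBLE_TAGS: FrozenSet[str] = frozenset({"c_only_rewrite"})


@dataclass(frozen=True)
class ToyMode:
    linker: str               # e.g. "c" or "jax"
    rewrites: FrozenSet[str]  # tags of the enabled graph rewrites


def option1_adapt_user_mode(user_mode: ToyMode) -> ToyMode:
    """Keep the user's rewrites, but drop JAX-incompatible ones since we target JAX."""
    return replace(user_mode, linker="jax",
                   rewrites=user_mode.rewrites - JAX_INCOMPATIBLE_TAGS)


def option2_inner_mode(scan_mode: ToyMode,
                       linker_mode: Optional[ToyMode] = None) -> ToyMode:
    """Dispatch takes an optional mode from the outer linker and prefers it."""
    return linker_mode if linker_mode is not None else scan_mode


# The situation in this issue: the Scan was built with a JAX mode,
# but model.initial_point() compiles the outer function with a C mode.
jax_scan = ToyMode("jax", frozenset({"fast_run"}))
c_outer = ToyMode("c", frozenset({"fast_run"}))
print(option2_inner_mode(jax_scan, c_outer).linker)  # c
```

With option 2, the inner function would follow the outer C mode, so it would never be handed a NumPy `Generator` while compiled to JAX.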

@ricardoV94
Member

ricardoV94 commented May 15, 2023

Yeah, it is trying to feed a NumPy generator as input. model.initial_point() creates a C function, but the inner Scan function is compiled to JAX. I think the solution is indeed to fix the mode handling. This happens in: https://github.com/pymc-devs/pytensor/blob/9ae07ab03bf417bd1c703ec624f494250621e7af/pytensor/link/jax/dispatch/scan.py#L20-L23

This would also fix #6697 which would be a big improvement.
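The immediate cause can be reproduced without PyMC. Per the traceback, JAX's `_shaped_abstractify_slow` falls back to reading a `dtype` attribute, and a NumPy `Generator` has none, while the arrays around it do. The check below only mirrors that condition; `looks_like_array` is a hypothetical stand-in, not JAX's function.

```python
import numpy as np


def looks_like_array(x) -> bool:
    """Hypothetical stand-in for the check that fails in the traceback."""
    # JAX's slow abstractify path raises TypeError when there is no dtype attribute
    return hasattr(x, "dtype")


rng = np.random.default_rng(0)
print(looks_like_array(np.array(1.0)))  # True
print(looks_like_array(rng))            # False: a Generator has no dtype
```

So any path that hands a live NumPy `Generator` to a JAX-compiled inner function, as model.initial_point() does here, would fail in the same way.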

@ricardoV94
Member

ricardoV94 commented May 15, 2023

An immediate workaround for your problem is to pass a valid numeric initval to model.register_rv:

model.register_rv(traj, name='traj', initval=np.zeros(100))
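The numeric initval has to match the shape of traj. With n_steps=100, Scan returns one value per step and does not include the initial state x0, so the array must have shape (100,). A minimal NumPy sketch of building such a value; `make_traj_initval` is a hypothetical helper, and zeros are simply the choice made in the comment above. A plausible reason the workaround helps is that a numeric initval means model.initial_point() no longer has to draw traj from the prior, which is the step that runs the JAX-compiled Scan with a NumPy `Generator`. That is an inference from the traceback, not something confirmed in the thread.

```python
import numpy as np


def make_traj_initval(n_steps: int, fill: float = 0.0) -> np.ndarray:
    """Hypothetical helper: a numeric initval with one entry per Scan step."""
    # x0 is its own variable, so it is not part of the trajectory
    return np.full(n_steps, fill, dtype="float64")


initval = make_traj_initval(n_steps=100)
print(initval.shape)  # (100,)
```

The result would then be passed as model.register_rv(traj, name='traj', initval=initval).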

@jessegrabowski
Member Author

Good to know! I can try to have a look at the mode problem as well over the next couple days if you're busy with other stuff.
