Open
Description
Describe the issue:
Sampling with observations containing missing values is failing in PyMC 5.20. `Metropolis` raises a `TypeError` when updating the sampling metadata. This does not occur when `Metropolis` is used for the entire model.
It appears that `stats['tune']` ends up being a string in this scenario instead of an array.
Reproducible code example:
# Minimal reproduction: a switchpoint model over a count series whose observed
# data contain missing values, sampled with the default pm.sample() settings.
import pymc as pm
import numpy as np

# Observed counts; the np.nan entries are the missing observations.
disasters_missing = np.array([
    4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
    3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
    2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
    1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
    0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
    3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
    0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
])
N = len(disasters_missing)

with pm.Model() as missing_data_model:
    # Prior for distribution of switchpoint location
    switchpoint = pm.DiscreteUniform('switchpoint', lower=0, upper=N)

    # Priors for pre- and post-switch mean number of disasters
    early_mean = pm.Exponential('early_mean', lam=1.)
    late_mean = pm.Exponential('late_mean', lam=1.)

    # Allocate appropriate Poisson rates to years before and after current
    # switchpoint location
    idx = np.arange(N)
    rate = pm.math.switch(switchpoint >= idx, early_mean, late_mean)

    # Data likelihood
    disasters = pm.Poisson('disasters', rate, observed=disasters_missing)

    # Sampling has to happen inside the model context because no explicit
    # `model=` argument is passed; this is the call that raises the TypeError.
    trace_missing = pm.sample()
Error message:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[60], line 2
1 with missing_data_model:
----> 2 trace_missing = pm.sample()
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:935, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
933 _print_step_hierarchy(step)
934 try:
--> 935 _mp_sample(**sample_args, **parallel_args)
936 except pickle.PickleError:
937 _log.warning("Could not pickle model, sampling singlethreaded.")
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1411, in _mp_sample(draws, tune, step, chains, cores, rngs, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
1409 try:
1410 with sampler:
-> 1411 for draw in sampler:
1412 strace = traces[draw.chain]
1413 if not zarr_recording:
1414 # Zarr recording happens in each process
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/parallel.py:513, in ParallelSampler.__iter__(self)
510 draw = ProcessAdapter.recv_draw(self._active)
511 proc, is_last, draw, tuning, stats = draw
--> 513 self._progress.update(
514 chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
515 )
517 if is_last:
518 proc.join()
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/util.py:886, in ProgressBarManager.update(self, chain_idx, is_last, draw, tuning, stats)
883 if not tuning and stats and stats[0].get("diverging"):
884 self.divergences += 1
--> 886 self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
887 more_updates = (
888 {stat: value[chain_idx] for stat, value in self.progress_stats.items()}
889 if self.full_stats
890 else {}
891 )
893 self._progress.update(
894 self.tasks[chain_idx],
895 completed=draw,
(...)
899 **more_updates,
900 )
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/step_methods/compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
338 def update_stats(stats, step_stats, chain_idx):
339 for step_stat, update_fn in zip(step_stats, update_fns):
--> 340 stats = update_fn(stats, step_stat, chain_idx)
342 return stats
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/step_methods/compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
338 def update_stats(stats, step_stats, chain_idx):
339 for step_stat, update_fn in zip(step_stats, update_fns):
--> 340 stats = update_fn(stats, step_stat, chain_idx)
342 return stats
File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/step_methods/metropolis.py:354, in Metropolis._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
351 if isinstance(step_stats, list):
352 step_stats = step_stats[0]
--> 354 stats["tune"][chain_idx] = step_stats["tune"]
355 stats["accept_rate"][chain_idx] = step_stats["accept"]
356 stats["scaling"][chain_idx] = step_stats["scaling"]
TypeError: string indices must be integers, not 'str'
PyMC version information:
pymc : 5.20.1
pytensor : 2.27.1
Context for the issue:
No response