Skip to content

BUG: Indexing TypeError when sampling with missing observations #7724

Open
@fonnesbeck

Description

@fonnesbeck

Describe the issue:

Sampling with observations containing missing values fails in PyMC 5.20. Metropolis raises a TypeError when updating the sampling metadata. This does not occur when Metropolis is used for the entire model.

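It appears that a string is being indexed with 'tune' in this scenario where a dict of stats is expected: the error message implies that either stats or step_stats is a string (if only stats['tune'] were a string instead of an array, the failure would be an item-assignment error rather than "string indices must be integers").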

Reproducible code example:

# Minimal reproducer: a Poisson switchpoint model whose observed counts
# contain NaN entries, sampled with the default pm.sample() settings.
import pymc as pm
import numpy as np

# Observed disaster counts; the np.nan entries are the missing observations
# this report is about.
disasters_missing = np.array([ 4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
2, 2, 3, 4, 2, 1, 3, np.nan, 2, 1, 1, 1, 1, 3, 0, 0,
1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
3, 3, 1, np.nan, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])

# Number of observations, counting the NaN placeholders.
N = len(disasters_missing)

with pm.Model() as missing_data_model:

    # Prior for distribution of switchpoint location
    switchpoint = pm.DiscreteUniform('switchpoint', lower=0, upper=N)
    # Priors for pre- and post-switch mean number of disasters
    early_mean = pm.Exponential('early_mean', lam=1.)
    late_mean = pm.Exponential('late_mean', lam=1.)

    # Allocate appropriate Poisson rates to years before and after current
    # switchpoint location
    idx = np.arange(N)
    rate = pm.math.switch(switchpoint >= idx, early_mean, late_mean)

    # Data likelihood
    # The observed array is passed with its NaN entries left in place.
    disasters = pm.Poisson('disasters', rate, observed=disasters_missing)

    # Sample with default settings (no step methods specified). Per the
    # traceback in this report, this is the call that raises the TypeError,
    # from Metropolis' update_stats nested inside CompoundStep.update_stats.
    # NOTE(review): presumably the step methods are auto-assigned here;
    # confirm against the printed step hierarchy.
    trace_missing = pm.sample()

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[60], line 2
      1 with missing_data_model:
----> 2     trace_missing = pm.sample()

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:935, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    933 _print_step_hierarchy(step)
    934 try:
--> 935     _mp_sample(**sample_args, **parallel_args)
    936 except pickle.PickleError:
    937     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1411, in _mp_sample(draws, tune, step, chains, cores, rngs, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
   1409 try:
   1410     with sampler:
-> 1411         for draw in sampler:
   1412             strace = traces[draw.chain]
   1413             if not zarr_recording:
   1414                 # Zarr recording happens in each process

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/parallel.py:513, in ParallelSampler.__iter__(self)
    510 draw = ProcessAdapter.recv_draw(self._active)
    511 proc, is_last, draw, tuning, stats = draw
--> 513 self._progress.update(
    514     chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
    515 )
    517 if is_last:
    518     proc.join()

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/util.py:886, in ProgressBarManager.update(self, chain_idx, is_last, draw, tuning, stats)
    883 if not tuning and stats and stats[0].get("diverging"):
    884     self.divergences += 1
--> 886 self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
    887 more_updates = (
    888     {stat: value[chain_idx] for stat, value in self.progress_stats.items()}
    889     if self.full_stats
    890     else {}
    891 )
    893 self._progress.update(
    894     self.tasks[chain_idx],
    895     completed=draw,
   (...)
    899     **more_updates,
    900 )

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/step_methods/compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/step_methods/compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File ~/labs/ccc-workshop/.pixi/envs/default/lib/python3.12/site-packages/pymc/step_methods/metropolis.py:354, in Metropolis._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    351 if isinstance(step_stats, list):
    352     step_stats = step_stats[0]
--> 354 stats["tune"][chain_idx] = step_stats["tune"]
    355 stats["accept_rate"][chain_idx] = step_stats["accept"]
    356 stats["scaling"][chain_idx] = step_stats["scaling"]

TypeError: string indices must be integers, not 'str'

PyMC version information:

pymc : 5.20.1
pytensor : 2.27.1

Context for the issue:

No response

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions