Skip to content

BUG: Progress bar throws error when nested CompoundSteps are present. #7721

Open
@fruzti

Description

@fruzti

Describe the issue:

The progress bar throws an error when a nested CompoundStep is present in a model's step hierarchy. When the progress bar is deactivated (i.e., progressbar=False), the error no longer occurs.

Reproducible code example:

with pm.Model() as modeWithErros:

    a   = pm.Poisson("a",mu=10)

    b   = pm.Binomial("b", n=a, p=0.8)

    c   = pm.Poisson("c",mu=11)

    d   = pm.Dirichlet("d",a=pt.stack([c,b]))

    pm.sample(draws=1000,tune=1000,chains=4)

Error message:

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>CompoundStep
>>Metropolis: [a]
>>Metropolis: [b]
>>Metropolis: [c]
>NUTS: [d]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[250], line 11
      7 c   = pm.Poisson("c",mu=11)
      9 d   = pm.Dirichlet("d",a=pt.stack([c,b]))
---> 11 pm.sample(draws=1000,tune=1000,chains=4)

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\mcmc.py:935, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    933 _print_step_hierarchy(step)
    934 try:
--> 935     _mp_sample(**sample_args, **parallel_args)
    936 except pickle.PickleError:
    937     _log.warning("Could not pickle model, sampling singlethreaded.")

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\mcmc.py:1411, in _mp_sample(draws, tune, step, chains, cores, rngs, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
   1409 try:
   1410     with sampler:
-> 1411         for draw in sampler:
   1412             strace = traces[draw.chain]
   1413             if not zarr_recording:
   1414                 # Zarr recording happens in each process

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\parallel.py:513, in ParallelSampler.__iter__(self)
    510 draw = ProcessAdapter.recv_draw(self._active)
    511 proc, is_last, draw, tuning, stats = draw
--> 513 self._progress.update(
    514     chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
    515 )
    517 if is_last:
    518     proc.join()

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\util.py:886, in ProgressBarManager.update(self, chain_idx, is_last, draw, tuning, stats)
    883 if not tuning and stats and stats[0].get("diverging"):
    884     self.divergences += 1
--> 886 self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
    887 more_updates = (
    888     {stat: value[chain_idx] for stat, value in self.progress_stats.items()}
    889     if self.full_stats
    890     else {}
    891 )
    893 self._progress.update(
    894     self.tasks[chain_idx],
    895     completed=draw,
   (...)
    899     **more_updates,
    900 )

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\metropolis.py:354, in Metropolis._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    351 if isinstance(step_stats, list):
    352     step_stats = step_stats[0]
--> 354 stats["tune"][chain_idx] = step_stats["tune"]
    355 stats["accept_rate"][chain_idx] = step_stats["accept"]
    356 stats["scaling"][chain_idx] = step_stats["scaling"]

TypeError: string indices must be integers, not 'str'

PyMC version information:

pymc 5.21.0

Context for the issue:

Since the error only occurs when the progress bar is active, it does not seem urgent. The workaround (passing progressbar=False to pm.sample) is noted here so that others hitting this error can find it.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions