Skip to content

Commit

Permalink
Remove some 'refresh' and make sure progress goes to 100%
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed May 17, 2024
1 parent d245a01 commit 292dae9
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 12 deletions.
20 changes: 16 additions & 4 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from pytensor.tensor.sharedvar import SharedVariable
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme

import pymc as pm
Expand Down Expand Up @@ -828,11 +829,21 @@ def sample_posterior_predictive(
# All model variables have a name, but mypy does not know this
_log.info(f"Sampling: {list(sorted(volatile_basic_rvs, key=lambda var: var.name))}") # type: ignore
ppc_trace_t = _DefaultTrace(samples)

progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)

try:
with CustomProgress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:
task = progress.add_task("Sampling ...", total=samples)
with progress:
task = progress.add_task("Sampling ...", completed=0, total=samples)
for idx in np.arange(samples):
if nchain > 1:
# the trace object will either be a MultiTrace (and have _straces)...
Expand All @@ -854,6 +865,7 @@ def sample_posterior_predictive(
ppc_trace_t.insert(k.name, v, idx)

progress.advance(task)
progress.update(task, refresh=True, completed=samples)

except KeyboardInterrupt:
pass
Expand Down
25 changes: 19 additions & 6 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from arviz.data.base import make_attrs
from pytensor.graph.basic import Variable
from rich.console import Console
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.theme import Theme
from threadpoolctl import threadpool_limits
from typing_extensions import Protocol
Expand Down Expand Up @@ -1075,16 +1076,28 @@ def _sample(
)
_pbar_data = {"chain": chain, "divergences": 0}
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
with CustomProgress(
console=Console(theme=progressbar_theme), disable=not progressbar
) as progress:

progress = CustomProgress(
"[progress.description]{task.description}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}%",
TimeRemainingColumn(),
TextColumn("/"),
TimeElapsedColumn(),
console=Console(theme=progressbar_theme),
disable=not progressbar,
)

with progress:
try:
task = progress.add_task(_desc.format(**_pbar_data), total=draws)
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
for it, diverging in enumerate(sampling_gen):
if it >= skip_first and diverging:
_pbar_data["divergences"] += 1
progress.update(task)
progress.update(task, refresh=True, advance=1, completed=True)
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
progress.update(
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
)
except KeyboardInterrupt:
pass

Expand Down
1 change: 0 additions & 1 deletion pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ def __iter__(self):
self._completed_draws += 1
if not tuning and stats and stats[0].get("diverging"):
self._divergences += 1

progress.update(
task,
completed=self._completed_draws,
Expand Down
1 change: 0 additions & 1 deletion pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def _sample_population(

with CustomProgress(disable=not progressbar) as progress:
task = progress.add_task("[red]Sampling...", total=draws)

for _ in sampling:
progress.update(task)

Expand Down

0 comments on commit 292dae9

Please sign in to comment.