Skip to content

Commit db3087a

Browse files
authored
Merge pull request #781 from stan-dev/cleanup-tqdm
Clean up TQDM usage
2 parents 797fa6e + a78b980 commit db3087a

File tree

1 file changed

+15
-30
lines changed

1 file changed

+15
-30
lines changed

cmdstanpy/model.py

+15-30
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,7 @@
2929
import pandas as pd
3030
from tqdm.auto import tqdm
3131

32-
from cmdstanpy import (
33-
_CMDSTAN_REFRESH,
34-
_CMDSTAN_SAMPLING,
35-
_CMDSTAN_WARMUP,
36-
_TMPDIR,
37-
compilation,
38-
)
32+
from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_WARMUP, _TMPDIR, compilation
3933
from cmdstanpy.cmdstan_args import (
4034
CmdStanArgs,
4135
GenerateQuantitiesArgs,
@@ -1069,9 +1063,6 @@ def sample(
10691063
iter_total += _CMDSTAN_SAMPLING
10701064
else:
10711065
iter_total += iter_sampling
1072-
if refresh is None:
1073-
refresh = _CMDSTAN_REFRESH
1074-
iter_total = iter_total // refresh + 2
10751066

10761067
progress_hook = self._wrap_sampler_progress_hook(
10771068
chain_ids=chain_ids,
@@ -2138,13 +2129,12 @@ def _wrap_sampler_progress_hook(
21382129
process, "Chain [id] Iteration" for multi-chain processing.
21392130
For the latter, manage array of pbars, update accordingly.
21402131
"""
2141-
pat = re.compile(r'Chain \[(\d*)\] (Iteration.*)')
2132+
chain_pat = re.compile(r'(Chain \[(\d+)\] )?Iteration:\s+(\d+)')
21422133
pbars: Dict[int, tqdm] = {
21432134
chain_id: tqdm(
21442135
total=total,
2145-
bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}",
2146-
postfix=[{"value": "Status"}],
21472136
desc=f'chain {chain_id}',
2137+
postfix='(Warmup)',
21482138
colour='yellow',
21492139
)
21502140
for chain_id in chain_ids
@@ -2153,23 +2143,19 @@ def _wrap_sampler_progress_hook(
21532143
def progress_hook(line: str, idx: int) -> None:
21542144
if line == "Done":
21552145
for pbar in pbars.values():
2156-
pbar.postfix[0]["value"] = 'Sampling completed'
2146+
pbar.set_postfix_str('(Sampling completed)')
21572147
pbar.update(total - pbar.n)
21582148
pbar.close()
2159-
else:
2160-
match = pat.match(line)
2161-
if match:
2162-
idx = int(match.group(1))
2163-
mline = match.group(2).strip()
2164-
elif line.startswith("Iteration"):
2165-
mline = line
2166-
idx = chain_ids[idx]
2167-
else:
2168-
return
2169-
if 'Sampling' in mline:
2170-
pbars[idx].colour = 'blue'
2171-
pbars[idx].update(1)
2172-
pbars[idx].postfix[0]["value"] = mline
2149+
elif (match := chain_pat.match(line)) is not None:
2150+
idx = int(match.group(2) or chain_ids[idx])
2151+
current_iter = int(match.group(3))
2152+
2153+
pbar = pbars[idx]
2154+
if pbar.colour == 'yellow' and 'Sampling' in line:
2155+
pbar.colour = 'blue'
2156+
pbar.set_postfix_str('(Sampling)')
2157+
2158+
pbar.update(current_iter - pbar.n)
21732159

21742160
return progress_hook
21752161

@@ -2225,8 +2211,7 @@ def diagnose(
22252211
Gradients are evaluated in the unconstrained space.
22262212
"""
22272213

2228-
with temp_single_json(data) as _data, \
2229-
temp_single_json(inits) as _inits:
2214+
with temp_single_json(data) as _data, temp_single_json(inits) as _inits:
22302215
cmd = [
22312216
str(self.exe_file),
22322217
"diagnose",

0 commit comments

Comments
 (0)