29
29
import pandas as pd
30
30
from tqdm .auto import tqdm
31
31
32
- from cmdstanpy import (
33
- _CMDSTAN_REFRESH ,
34
- _CMDSTAN_SAMPLING ,
35
- _CMDSTAN_WARMUP ,
36
- _TMPDIR ,
37
- compilation ,
38
- )
32
+ from cmdstanpy import _CMDSTAN_SAMPLING , _CMDSTAN_WARMUP , _TMPDIR , compilation
39
33
from cmdstanpy .cmdstan_args import (
40
34
CmdStanArgs ,
41
35
GenerateQuantitiesArgs ,
@@ -1069,9 +1063,6 @@ def sample(
1069
1063
iter_total += _CMDSTAN_SAMPLING
1070
1064
else :
1071
1065
iter_total += iter_sampling
1072
- if refresh is None :
1073
- refresh = _CMDSTAN_REFRESH
1074
- iter_total = iter_total // refresh + 2
1075
1066
1076
1067
progress_hook = self ._wrap_sampler_progress_hook (
1077
1068
chain_ids = chain_ids ,
@@ -2138,13 +2129,12 @@ def _wrap_sampler_progress_hook(
2138
2129
process, "Chain [id] Iteration" for multi-chain processing.
2139
2130
For the latter, manage array of pbars, update accordingly.
2140
2131
"""
2141
- pat = re .compile (r'Chain \[(\d* )\] ( Iteration.* )' )
2132
+ chain_pat = re .compile (r'( Chain \[(\d+ )\] )? Iteration:\s+(\d+ )' )
2142
2133
pbars : Dict [int , tqdm ] = {
2143
2134
chain_id : tqdm (
2144
2135
total = total ,
2145
- bar_format = "{desc} |{bar}| {elapsed} {postfix[0][value]}" ,
2146
- postfix = [{"value" : "Status" }],
2147
2136
desc = f'chain { chain_id } ' ,
2137
+ postfix = '(Warmup)' ,
2148
2138
colour = 'yellow' ,
2149
2139
)
2150
2140
for chain_id in chain_ids
@@ -2153,23 +2143,19 @@ def _wrap_sampler_progress_hook(
2153
2143
def progress_hook (line : str , idx : int ) -> None :
2154
2144
if line == "Done" :
2155
2145
for pbar in pbars .values ():
2156
- pbar .postfix [ 0 ][ "value" ] = ' Sampling completed'
2146
+ pbar .set_postfix_str ( '( Sampling completed)' )
2157
2147
pbar .update (total - pbar .n )
2158
2148
pbar .close ()
2159
- else :
2160
- match = pat .match (line )
2161
- if match :
2162
- idx = int (match .group (1 ))
2163
- mline = match .group (2 ).strip ()
2164
- elif line .startswith ("Iteration" ):
2165
- mline = line
2166
- idx = chain_ids [idx ]
2167
- else :
2168
- return
2169
- if 'Sampling' in mline :
2170
- pbars [idx ].colour = 'blue'
2171
- pbars [idx ].update (1 )
2172
- pbars [idx ].postfix [0 ]["value" ] = mline
2149
+ elif (match := chain_pat .match (line )) is not None :
2150
+ idx = int (match .group (2 ) or chain_ids [idx ])
2151
+ current_iter = int (match .group (3 ))
2152
+
2153
+ pbar = pbars [idx ]
2154
+ if pbar .colour == 'yellow' and 'Sampling' in line :
2155
+ pbar .colour = 'blue'
2156
+ pbar .set_postfix_str ('(Sampling)' )
2157
+
2158
+ pbar .update (current_iter - pbar .n )
2173
2159
2174
2160
return progress_hook
2175
2161
@@ -2225,8 +2211,7 @@ def diagnose(
2225
2211
Gradients are evaluated in the unconstrained space.
2226
2212
"""
2227
2213
2228
- with temp_single_json (data ) as _data , \
2229
- temp_single_json (inits ) as _inits :
2214
+ with temp_single_json (data ) as _data , temp_single_json (inits ) as _inits :
2230
2215
cmd = [
2231
2216
str (self .exe_file ),
2232
2217
"diagnose" ,
0 commit comments