4040
4141
4242class ParallelSamplingError (Exception ):
43- def __init__ (self , message , chain , warnings = None ):
43+ def __init__ (self , message , chain ):
4444 super ().__init__ (message )
45- if warnings is None :
46- warnings = []
4745 self ._chain = chain
48- self ._warnings = warnings
4946
5047
5148# Taken from https://hg.python.org/cpython/rev/c4f92b597074
@@ -74,8 +71,8 @@ def rebuild_exc(exc, tb):
7471
7572
7673# Messages
77- # ('writing_done', is_last, sample_idx, tuning, stats, warns )
78- # ('error', warnings, *exception_info)
74+ # ('writing_done', is_last, sample_idx, tuning, stats)
75+ # ('error', *exception_info)
7976
8077# ('abort', reason)
8178# ('write_next',)
@@ -133,7 +130,7 @@ def run(self):
133130 e = ExceptionWithTraceback (e , e .__traceback__ )
134131 # Send is not blocking so we have to force a wait for the abort
135132 # message
136- self ._msg_pipe .send (("error" , None , e ))
133+ self ._msg_pipe .send (("error" , e ))
137134 self ._wait_for_abortion ()
138135 finally :
139136 self ._msg_pipe .close ()
@@ -181,9 +178,8 @@ def _start_loop(self):
181178 try :
182179 point , stats = self ._compute_point ()
183180 except SamplingError as e :
184- warns = self ._collect_warnings ()
185181 e = ExceptionWithTraceback (e , e .__traceback__ )
186- self ._msg_pipe .send (("error" , warns , e ))
182+ self ._msg_pipe .send (("error" , e ))
187183 else :
188184 return
189185
@@ -193,11 +189,7 @@ def _start_loop(self):
193189 elif msg [0 ] == "write_next" :
194190 self ._write_point (point )
195191 is_last = draw + 1 == self ._draws + self ._tune
196- if is_last :
197- warns = self ._collect_warnings ()
198- else :
199- warns = None
200- self ._msg_pipe .send (("writing_done" , is_last , draw , tuning , stats , warns ))
192+ self ._msg_pipe .send (("writing_done" , is_last , draw , tuning , stats ))
201193 draw += 1
202194 else :
203195 raise ValueError ("Unknown message " + msg [0 ])
@@ -210,12 +202,6 @@ def _compute_point(self):
210202 stats = None
211203 return point , stats
212204
213- def _collect_warnings (self ):
214- if hasattr (self ._step_method , "warnings" ):
215- return self ._step_method .warnings ()
216- else :
217- return []
218-
219205
220206def _run_process (* args ):
221207 _Process (* args ).run ()
@@ -308,11 +294,13 @@ def _send(self, msg, *args):
308294 except Exception :
309295 pass
310296 if message is not None and message [0 ] == "error" :
311- warns , old_error = message [1 :]
312- if warns is not None :
313- error = ParallelSamplingError (str (old_error ), self .chain , warns )
297+ old_error = message [1 ]
298+ if old_error is not None :
299+ error = ParallelSamplingError (
300+ f"Chain { self .chain } failed with: { old_error } " , self .chain
301+ )
314302 else :
315- error = RuntimeError ("Chain %s failed." % self . chain )
303+ error = RuntimeError (f "Chain { self . chain } failed." )
316304 raise error from old_error
317305 raise
318306
@@ -345,11 +333,13 @@ def recv_draw(processes, timeout=3600):
345333 msg = ready [0 ].recv ()
346334
347335 if msg [0 ] == "error" :
348- warns , old_error = msg [1 :]
349- if warns is not None :
350- error = ParallelSamplingError (str (old_error ), proc .chain , warns )
336+ old_error = msg [1 ]
337+ if old_error is not None :
338+ error = ParallelSamplingError (
339+ f"Chain { proc .chain } failed with: { old_error } " , proc .chain
340+ )
351341 else :
352- error = RuntimeError ("Chain %s failed." % proc . chain )
342+ error = RuntimeError (f "Chain { proc . chain } failed." )
353343 raise error from old_error
354344 elif msg [0 ] == "writing_done" :
355345 proc ._readable = True
@@ -383,7 +373,7 @@ def terminate_all(processes, patience=2):
383373 process .join ()
384374
385375
386- Draw = namedtuple ("Draw" , ["chain" , "is_last" , "draw_idx" , "tuning" , "stats" , "point" , "warnings" ])
376+ Draw = namedtuple ("Draw" , ["chain" , "is_last" , "draw_idx" , "tuning" , "stats" , "point" ])
387377
388378
389379class ParallelSampler :
@@ -461,7 +451,7 @@ def __iter__(self):
461451
462452 while self ._active :
463453 draw = ProcessAdapter .recv_draw (self ._active )
464- proc , is_last , draw , tuning , stats , warns = draw
454+ proc , is_last , draw , tuning , stats = draw
465455 self ._total_draws += 1
466456 if not tuning and stats and stats [0 ].get ("diverging" ):
467457 self ._divergences += 1
@@ -486,7 +476,7 @@ def __iter__(self):
486476 if not is_last :
487477 proc .write_next ()
488478
489- yield Draw (proc .chain , is_last , draw , tuning , stats , point , warns )
479+ yield Draw (proc .chain , is_last , draw , tuning , stats , point )
490480
491481 def __enter__ (self ):
492482 self ._in_context = True
0 commit comments