33
44The only PyMC dependency is on the ``BaseTrace`` abstract base class.
55"""
6+ import base64
7+ import pickle
68from typing import Dict , List , Optional , Sequence , Tuple
79
810import hagelkorn
@@ -159,12 +161,16 @@ def setup(
159161 self ._stat_groups .append ([])
160162 for statname , dtype in names_dtypes .items ():
161163 sname = f"sampler_{ s } __{ statname } "
162- svar = Variable (
163- name = sname ,
164- dtype = numpy .dtype (dtype ).name ,
165- # This 👇 is needed until PyMC provides shapes ahead of time.
166- undefined_ndim = True ,
167- )
164+ if statname == "warning" :
165+ # SamplerWarnings will be pickled and stored as string!
166+ svar = Variable (sname , "str" )
167+ else :
168+ svar = Variable (
169+ name = sname ,
170+ dtype = numpy .dtype (dtype ).name ,
171+ # This 👇 is needed until PyMC provides shapes ahead of time.
172+ undefined_ndim = True ,
173+ )
168174 self ._stat_groups [s ].append ((sname , statname ))
169175 sample_stats .append (svar )
170176
@@ -197,8 +203,12 @@ def record(self, point, sampler_states=None):
197203 for s , sts in enumerate (sampler_states ):
198204 for statname , sval in sts .items ():
199205 sname = f"sampler_{ s } __{ statname } "
200- stats [sname ] = sval
201- # Make not whether this is a tuning iteration.
206+ # Automatically pickle SamplerWarnings
207+ if statname == "warning" :
208+ sval_bytes = pickle .dumps (sval )
209+ sval = base64 .encodebytes (sval_bytes ).decode ("ascii" )
210+ stats [sname ] = numpy .asarray (sval )
211+ # Make note whether this is a tuning iteration.
202212 if statname == "tune" :
203213 stats ["tune" ] = sval
204214
@@ -214,7 +224,16 @@ def get_values(self, varname, burn=0, thin=1) -> numpy.ndarray:
214224 def _get_stats (self , varname , burn = 0 , thin = 1 ) -> numpy .ndarray :
215225 if self ._chain is None :
216226 raise Exception ("Trace setup was not completed. Call `.setup()` first." )
217- return self ._chain .get_stats (varname )[burn ::thin ]
227+ values = self ._chain .get_stats (varname )[burn ::thin ]
228+ if "warning" in varname :
229+ objs = []
230+ for v in values :
231+ enc = v .encode ("ascii" )
232+ str_ = base64 .decodebytes (enc )
233+ obj = pickle .loads (str_ )
234+ objs .append (obj )
235+ values = numpy .array (objs , dtype = object )
236+ return values
218237
219238 def _get_sampler_stats (self , stat_name , sampler_idx , burn , thin ):
220239 if self ._chain is None :
0 commit comments