@@ -277,11 +277,12 @@ def test__get_chains(self):
277277 assert len (chain ) == 1
278278 pass
279279
280- def test__to_inferencedata (self ):
280+ @pytest .mark .parametrize ("tstatname" , ["tune" , "sampler__tune" , "nottune" ])
281+ def test__to_inferencedata (self , tstatname , caplog ):
281282 rmeta = make_runmeta (
282283 flexibility = False ,
283284 sample_stats = [
284- Variable ("tune" , "bool" ),
285+ Variable (tstatname , "bool" ),
285286 Variable ("sampler_0__logp" , "float32" ),
286287 Variable ("warning" , "str" ),
287288 ],
@@ -294,15 +295,22 @@ def test__to_inferencedata(self):
294295 draws = [make_draw (rmeta .variables ) for _ in range (n )]
295296 stats = [make_draw (rmeta .sample_stats ) for _ in range (n )]
296297 for i , (d , s ) in enumerate (zip (draws , stats )):
297- s ["tune" ] = i < 4
298+ s [tstatname ] = i < 4
298299 chain .append (d , s )
299300
300301 idata = run .to_inferencedata ()
301302 assert isinstance (idata , arviz .InferenceData )
302303 assert idata .warmup_posterior .dims ["chain" ] == 1
303- assert idata .warmup_posterior .dims ["draw" ] == 4
304304 assert idata .posterior .dims ["chain" ] == 1
305- assert idata .posterior .dims ["draw" ] == 6
305+ if tstatname == "nottune" :
306+ # Splitting into warmup/posterior requires a tune stat!
307+ assert any ("No 'tune' stat" in r .message for r in caplog .records )
308+ assert idata .warmup_posterior .dims ["draw" ] == 0
309+ assert idata .posterior .dims ["draw" ] == 10
310+ else :
311+ assert idata .warmup_posterior .dims ["draw" ] == 4
312+ assert idata .posterior .dims ["draw" ] == 6
313+
306314 for var in rmeta .variables :
307315 assert var .name in set (idata .posterior .keys ())
308316 for svar in rmeta .sample_stats :
0 commit comments