@@ -29,6 +29,8 @@ def make_runmeta(*, flexibility: bool = False, **kwargs) -> RunMeta:
2929 Variable ("accepted" , "bool" , list ((3 ,)), dims = ["sampler" ]),
3030 # But some stats may refer to the iteration.
3131 Variable ("logp" , "float64" , []),
32+ # String dtypes may be used for more complex information
33+ Variable ("message" , "str" ),
3234 ],
3335 data = [
3436 DataVariable (
@@ -60,8 +62,16 @@ def make_draw(variables: Sequence[Variable]):
6062 )
6163 if "float" in var .dtype :
6264 draw [var .name ] = numpy .random .normal (size = dshape ).astype (var .dtype )
65+ elif var .dtype == "str" :
66+ alphabet = tuple ("abcdef#+*/'" )
67+ words = [
68+ "" .join (numpy .random .choice (alphabet , size = numpy .random .randint (3 , 10 )))
69+ for _ in range (int (numpy .prod (dshape )))
70+ ]
71+ draw [var .name ] = numpy .array (words , dtype = var .dtype ).reshape (dshape )
6372 else :
6473 draw [var .name ] = numpy .random .randint (low = 0 , high = 100 , size = dshape ).astype (var .dtype )
74+ assert draw [var .name ].shape == dshape
6575 return draw
6676
6777
@@ -149,7 +159,7 @@ def test__append_get_with_changelings(self, with_stats):
149159 expected = [draw [var .name ] for draw in draws ]
150160 actual = chain .get_draws (var .name )
151161 assert isinstance (actual , numpy .ndarray )
152- if var .name == "changeling " :
162+ if not is_rigid ( var .shape ) or var . dtype == "str " :
153163 # Non-ridid variables are returned as object-arrays.
154164 assert actual .shape == (len (expected ),)
155165 assert actual .dtype == object
@@ -166,9 +176,13 @@ def test__append_get_with_changelings(self, with_stats):
166176 expected = [stat [var .name ] for stat in stats ]
167177 actual = chain .get_stats (var .name )
168178 assert isinstance (actual , numpy .ndarray )
169- if is_rigid ( var .shape ) :
179+ if var .dtype == "str" :
170180 assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
171- assert actual .dtype == var .dtype
181+ # String dtypes have strange names
182+ assert "str" in actual .dtype .name
183+ elif is_rigid (var .shape ):
184+ assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
185+ assert actual .dtype .name == var .dtype
172186 numpy .testing .assert_array_equal (actual , expected )
173187 else :
174188 # Non-ridid variables are returned as object-arrays.
@@ -200,7 +214,7 @@ def test__get_slicing(self, slc: slice):
200214 # "A" are just numbers to make diagnosis easier.
201215 # "B" are dynamically shaped to cover the edge cases.
202216 rmeta = RunMeta (
203- variables = [Variable ("A" , "uint8" )],
217+ variables = [Variable ("A" , "uint8" ), Variable ( "M" , "str" , [ 2 , 3 ]) ],
204218 sample_stats = [Variable ("B" , "uint8" , [2 , 0 ])],
205219 data = [],
206220 )
@@ -209,7 +223,7 @@ def test__get_slicing(self, slc: slice):
209223
210224 # Generate draws and add them to the chain
211225 N = 20
212- draws = [dict ( A = n ) for n in range (N )]
226+ draws = [make_draw ( rmeta . variables ) for n in range (N )]
213227 stats = [make_draw (rmeta .sample_stats ) for n in range (N )]
214228 for d , s in zip (draws , stats ):
215229 chain .append (d , s )
@@ -218,12 +232,25 @@ def test__get_slicing(self, slc: slice):
218232 # slc=None in this test means "don't pass it".
219233 # The implementations should default to slc=slice(None, None, None).
220234 kwargs = dict (slc = slc ) if slc is not None else {}
221- act_draws = chain .get_draws ("A" , ** kwargs )
235+ act_draws_A = chain .get_draws ("A" , ** kwargs )
236+ act_draws_M = chain .get_draws ("M" , ** kwargs )
222237 act_stats = chain .get_stats ("B" , ** kwargs )
223- expected_draws = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
238+ expected_draws_A = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
239+ expected_draws_M = [d ["M" ] for d in draws ][slc or slice (None , None , None )]
224240 expected_stats = [s ["B" ] for s in stats ][slc or slice (None , None , None )]
241+
225242 # Variable "A" has a rigid shape
226- numpy .testing .assert_array_equal (act_draws , expected_draws )
243+ if expected_draws_A :
244+ numpy .testing .assert_array_equal (act_draws_A , expected_draws_A )
245+ else :
246+ assert len (act_draws_A ) == 0
247+
248+ # Variable "M" is a string matrix
249+ if expected_draws_M :
250+ numpy .testing .assert_array_equal (act_draws_M , expected_draws_M )
251+ else :
252+ assert len (act_draws_M ) == 0
253+
227254 # Stat "B" is dynamically shaped, which means we're dealing with
228255 # dtype=object arrays. These must be checked elementwise.
229256 assert len (act_stats ) == len (expected_stats )
@@ -256,6 +283,7 @@ def test__to_inferencedata(self):
256283 sample_stats = [
257284 Variable ("tune" , "bool" ),
258285 Variable ("sampler_0__logp" , "float32" ),
286+ Variable ("warning" , "str" ),
259287 ],
260288 )
261289 run = self .backend .init_run (rmeta )
0 commit comments