@@ -29,6 +29,8 @@ def make_runmeta(*, flexibility: bool = False, **kwargs) -> RunMeta:
2929 Variable ("accepted" , "bool" , list ((3 ,)), dims = ["sampler" ]),
3030 # But some stats may refer to the iteration.
3131 Variable ("logp" , "float64" , []),
32+ # String dtypes may be used for more complex information
33+ Variable ("message" , "str" ),
3234 ],
3335 data = [
3436 DataVariable (
@@ -60,8 +62,13 @@ def make_draw(variables: Sequence[Variable]):
6062 )
6163 if "float" in var .dtype :
6264 draw [var .name ] = numpy .random .normal (size = dshape ).astype (var .dtype )
65+ elif var .dtype == "str" :
66+ alphabet = tuple ("abcdef#+*/'" )
67+ words = ["" .join (numpy .random .choice (alphabet , size = numpy .random .randint (3 , 10 )))]
68+ draw [var .name ] = numpy .array (words ).reshape (dshape )
6369 else :
6470 draw [var .name ] = numpy .random .randint (low = 0 , high = 100 , size = dshape ).astype (var .dtype )
71+ assert draw [var .name ].shape == dshape
6572 return draw
6673
6774
@@ -149,7 +156,7 @@ def test__append_get_with_changelings(self, with_stats):
149156 expected = [draw [var .name ] for draw in draws ]
150157 actual = chain .get_draws (var .name )
151158 assert isinstance (actual , numpy .ndarray )
152- if var .name == "changeling " :
159+ if not is_rigid ( var .shape ) or var . dtype == "str " :
153160 # Non-ridid variables are returned as object-arrays.
154161 assert actual .shape == (len (expected ),)
155162 assert actual .dtype == object
@@ -166,9 +173,13 @@ def test__append_get_with_changelings(self, with_stats):
166173 expected = [stat [var .name ] for stat in stats ]
167174 actual = chain .get_stats (var .name )
168175 assert isinstance (actual , numpy .ndarray )
169- if is_rigid ( var .shape ) :
176+ if var .dtype == "str" :
170177 assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
171- assert actual .dtype == var .dtype
178+ # String dtypes have strange names
179+ assert "str" in actual .dtype .name
180+ elif is_rigid (var .shape ):
181+ assert tuple (actual .shape ) == tuple (numpy .shape (expected ))
182+ assert actual .dtype .name == var .dtype
172183 numpy .testing .assert_array_equal (actual , expected )
173184 else :
174185 # Non-ridid variables are returned as object-arrays.
@@ -200,7 +211,7 @@ def test__get_slicing(self, slc: slice):
200211 # "A" are just numbers to make diagnosis easier.
201212 # "B" are dynamically shaped to cover the edge cases.
202213 rmeta = RunMeta (
203- variables = [Variable ("A" , "uint8" )],
214+ variables = [Variable ("A" , "uint8" ), Variable ( "M" , "str" , [ 2 , 3 ]) ],
204215 sample_stats = [Variable ("B" , "uint8" , [2 , 0 ])],
205216 data = [],
206217 )
@@ -209,7 +220,7 @@ def test__get_slicing(self, slc: slice):
209220
210221 # Generate draws and add them to the chain
211222 N = 20
212- draws = [dict (A = n ) for n in range (N )]
223+ draws = [dict (A = numpy . array ( n ) ) for n in range (N )]
213224 stats = [make_draw (rmeta .sample_stats ) for n in range (N )]
214225 for d , s in zip (draws , stats ):
215226 chain .append (d , s )
@@ -222,8 +233,13 @@ def test__get_slicing(self, slc: slice):
222233 act_stats = chain .get_stats ("B" , ** kwargs )
223234 expected_draws = [d ["A" ] for d in draws ][slc or slice (None , None , None )]
224235 expected_stats = [s ["B" ] for s in stats ][slc or slice (None , None , None )]
236+
225237 # Variable "A" has a rigid shape
226238 numpy .testing .assert_array_equal (act_draws , expected_draws )
239+
240+ # Variable "M" is a string matrix
241+ numpy .testing .assert_array_equal (act_draws , expected_draws )
242+
227243 # Stat "B" is dynamically shaped, which means we're dealing with
228244 # dtype=object arrays. These must be checked elementwise.
229245 assert len (act_stats ) == len (expected_stats )
@@ -256,6 +272,7 @@ def test__to_inferencedata(self):
256272 sample_stats = [
257273 Variable ("tune" , "bool" ),
258274 Variable ("sampler_0__logp" , "float32" ),
275+ Variable ("warning" , "str" ),
259276 ],
260277 )
261278 run = self .backend .init_run (rmeta )
0 commit comments