@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
103
103
assert_size_argument_jax_compatible (node )
104
104
105
105
def sample_fn (rng , size , dtype , * parameters ):
106
+ # PyTensor uses empty size to represent size = None
107
+ if jax .numpy .asarray (size ).shape == (0 ,):
108
+ size = None
106
109
return jax_sample_fn (op )(rng , size , out_dtype , * parameters )
107
110
108
111
else :
@@ -161,6 +164,8 @@ def sample_fn(rng, size, dtype, *parameters):
161
164
rng_key = rng ["jax_state" ]
162
165
rng_key , sampling_key = jax .random .split (rng_key , 2 )
163
166
loc , scale = parameters
167
+ if size is None :
168
+ size = jax .numpy .broadcast_arrays (loc , scale )[0 ].shape
164
169
sample = loc + jax_op (sampling_key , size , dtype ) * scale
165
170
rng ["jax_state" ] = rng_key
166
171
return (rng , sample )
@@ -184,15 +189,16 @@ def sample_fn(rng, size, dtype, p):
184
189
185
190
186
191
@jax_sample_fn .register (ptr .CategoricalRV )
187
- def jax_sample_fn_no_dtype (op ):
188
- """Generic JAX implementation of random variables."""
189
- name = op .name
190
- jax_op = getattr (jax .random , name )
192
+ def jax_sample_fn_categorical (op ):
193
+ """JAX implementation of `CategoricalRV`."""
191
194
192
- def sample_fn (rng , size , dtype , * parameters ):
195
+ # We need a separate dispatch because Categorical expects logits in JAX
196
+ def sample_fn (rng , size , dtype , p ):
193
197
rng_key = rng ["jax_state" ]
194
198
rng_key , sampling_key = jax .random .split (rng_key , 2 )
195
- sample = jax_op (sampling_key , * parameters , shape = size )
199
+
200
+ logits = jax .scipy .special .logit (p )
201
+ sample = jax .random .categorical (sampling_key , logits = logits , shape = size )
196
202
rng ["jax_state" ] = rng_key
197
203
return (rng , sample )
198
204
@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
243
249
def sample_fn (rng , size , dtype , shape , scale ):
244
250
rng_key = rng ["jax_state" ]
245
251
rng_key , sampling_key = jax .random .split (rng_key , 2 )
252
+ if size is None :
253
+ size = jax .numpy .broadcast_arrays (shape , scale )[0 ].shape
246
254
sample = jax_op (sampling_key , shape , size , dtype ) * scale
247
255
rng ["jax_state" ] = rng_key
248
256
return (rng , sample )
@@ -254,10 +262,11 @@ def sample_fn(rng, size, dtype, shape, scale):
254
262
def jax_sample_fn_exponential (op ):
255
263
"""JAX implementation of `ExponentialRV`."""
256
264
257
- def sample_fn (rng , size , dtype , * parameters ):
265
+ def sample_fn (rng , size , dtype , scale ):
258
266
rng_key = rng ["jax_state" ]
259
267
rng_key , sampling_key = jax .random .split (rng_key , 2 )
260
- (scale ,) = parameters
268
+ if size is None :
269
+ size = jax .numpy .asarray (scale ).shape
261
270
sample = jax .random .exponential (sampling_key , size , dtype ) * scale
262
271
rng ["jax_state" ] = rng_key
263
272
return (rng , sample )
@@ -269,14 +278,11 @@ def sample_fn(rng, size, dtype, *parameters):
269
278
def jax_sample_fn_t (op ):
270
279
"""JAX implementation of `StudentTRV`."""
271
280
272
- def sample_fn (rng , size , dtype , * parameters ):
281
+ def sample_fn (rng , size , dtype , df , loc , scale ):
273
282
rng_key = rng ["jax_state" ]
274
283
rng_key , sampling_key = jax .random .split (rng_key , 2 )
275
- (
276
- df ,
277
- loc ,
278
- scale ,
279
- ) = parameters
284
+ if size is None :
285
+ size = jax .numpy .broadcast_arrays (df , loc , scale )[0 ].shape
280
286
sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
281
287
rng ["jax_state" ] = rng_key
282
288
return (rng , sample )
0 commit comments