Skip to content

Commit 90da9e6

Browse files
committed
Fix bug in JAX implementation of RandomVariables with implicit size
1 parent 17a5e42 commit 90da9e6

File tree

2 files changed

+48
-14
lines changed

2 files changed

+48
-14
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
103103
assert_size_argument_jax_compatible(node)
104104

105105
def sample_fn(rng, size, dtype, *parameters):
106+
# PyTensor uses empty size to represent size = None
107+
if jax.numpy.asarray(size).shape == (0,):
108+
size = None
106109
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
107110

108111
else:
@@ -161,6 +164,8 @@ def sample_fn(rng, size, dtype, *parameters):
161164
rng_key = rng["jax_state"]
162165
rng_key, sampling_key = jax.random.split(rng_key, 2)
163166
loc, scale = parameters
167+
if size is None:
168+
size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
164169
sample = loc + jax_op(sampling_key, size, dtype) * scale
165170
rng["jax_state"] = rng_key
166171
return (rng, sample)
@@ -184,15 +189,16 @@ def sample_fn(rng, size, dtype, p):
184189

185190

186191
@jax_sample_fn.register(ptr.CategoricalRV)
187-
def jax_sample_fn_no_dtype(op):
188-
"""Generic JAX implementation of random variables."""
189-
name = op.name
190-
jax_op = getattr(jax.random, name)
192+
def jax_sample_fn_categorical(op):
193+
"""JAX implementation of `CategoricalRV`."""
191194

192-
def sample_fn(rng, size, dtype, *parameters):
195+
# We need a separate dispatch because Categorical expects logits in JAX
196+
def sample_fn(rng, size, dtype, p):
193197
rng_key = rng["jax_state"]
194198
rng_key, sampling_key = jax.random.split(rng_key, 2)
195-
sample = jax_op(sampling_key, *parameters, shape=size)
199+
200+
logits = jax.scipy.special.logit(p)
201+
sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
196202
rng["jax_state"] = rng_key
197203
return (rng, sample)
198204

@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
243249
def sample_fn(rng, size, dtype, shape, scale):
244250
rng_key = rng["jax_state"]
245251
rng_key, sampling_key = jax.random.split(rng_key, 2)
252+
if size is None:
253+
size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
246254
sample = jax_op(sampling_key, shape, size, dtype) * scale
247255
rng["jax_state"] = rng_key
248256
return (rng, sample)
@@ -254,10 +262,11 @@ def sample_fn(rng, size, dtype, shape, scale):
254262
def jax_sample_fn_exponential(op):
255263
"""JAX implementation of `ExponentialRV`."""
256264

257-
def sample_fn(rng, size, dtype, *parameters):
265+
def sample_fn(rng, size, dtype, scale):
258266
rng_key = rng["jax_state"]
259267
rng_key, sampling_key = jax.random.split(rng_key, 2)
260-
(scale,) = parameters
268+
if size is None:
269+
size = jax.numpy.asarray(scale).shape
261270
sample = jax.random.exponential(sampling_key, size, dtype) * scale
262271
rng["jax_state"] = rng_key
263272
return (rng, sample)
@@ -269,14 +278,11 @@ def sample_fn(rng, size, dtype, *parameters):
269278
def jax_sample_fn_t(op):
270279
"""JAX implementation of `StudentTRV`."""
271280

272-
def sample_fn(rng, size, dtype, *parameters):
281+
def sample_fn(rng, size, dtype, df, loc, scale):
273282
rng_key = rng["jax_state"]
274283
rng_key, sampling_key = jax.random.split(rng_key, 2)
275-
(
276-
df,
277-
loc,
278-
scale,
279-
) = parameters
284+
if size is None:
285+
size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
280286
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
281287
rng["jax_state"] = rng_key
282288
return (rng, sample)

tests/link/jax/test_random.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,34 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
509509
assert test_res.pvalue > 0.01
510510

511511

512+
@pytest.mark.parametrize(
513+
"rv_fn",
514+
[
515+
lambda param_that_implies_size: ptr.normal(
516+
loc=0, scale=pt.exp(param_that_implies_size)
517+
),
518+
lambda param_that_implies_size: ptr.exponential(
519+
scale=pt.exp(param_that_implies_size)
520+
),
521+
lambda param_that_implies_size: ptr.gamma(
522+
shape=1, scale=pt.exp(param_that_implies_size)
523+
),
524+
lambda param_that_implies_size: ptr.t(
525+
df=3, loc=param_that_implies_size, scale=1
526+
),
527+
],
528+
)
529+
def test_size_implied_by_broadcasted_parameters(rv_fn):
530+
# We need a parameter with untyped shapes to test broadcasting does not result in identical draws
531+
param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))
532+
533+
rv = rv_fn(param_that_implies_size)
534+
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)
535+
536+
assert draws.shape == (2, 2)
537+
assert np.unique(draws).size == 4
538+
539+
512540
@pytest.mark.parametrize("size", [(), (4,)])
513541
def test_random_bernoulli(size):
514542
rng = shared(np.random.RandomState(123))

0 commit comments

Comments
 (0)