jax.random.choice(replace=True) samples 0 probability index #25498
Open
Description
jax.random.choice(replace=True) can return indices whose probability is 0 when the input array is large and the average probability is low (~1e-07):
import jax
import jax.numpy as jnp
import numpy as np

# 7M entries: the first 5M share the probability mass equally, the rest are 0.
sample_prob = np.zeros((7000000,))
sample_prob[:5000000] = 1.0
sample_prob = jnp.array(sample_prob / (sample_prob.sum()))
print(sample_prob.max())  # Output: 2e-07
print(sample_prob.min())  # Output: 0.0
sampled_idxs = jax.random.choice(
    jax.random.PRNGKey(0),
    a=jnp.arange(len(sample_prob)),
    shape=(len(sample_prob),),
    p=sample_prob,
    replace=True,
)
# Every sampled index should have non-zero probability.
print((sample_prob[sampled_idxs]).min())  # Output: 0.0, shouldn't happen
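Checking only the minimum shows that at least one zero-probability index was drawn, but not how many. A minimal sketch of a counting helper follows; the name count_zero_prob_samples is hypothetical and not part of JAX or NumPy, and the indices in the self-check are hand-written rather than real sampler output.

```python
import numpy as np


def count_zero_prob_samples(p, idxs):
    """Return how many sampled indices point at entries with probability 0."""
    p = np.asarray(p)        # also accepts a JAX array via conversion
    idxs = np.asarray(idxs)
    return int(np.count_nonzero(p[idxs] == 0))


# Self-check with hand-written indices (illustrative only).
p = np.array([0.5, 0.5, 0.0, 0.0], dtype=np.float32)
print(count_zero_prob_samples(p, [0, 1, 1, 0]))  # 0
print(count_zero_prob_samples(p, [0, 3, 2, 1]))  # 2
```

Calling count_zero_prob_samples(sample_prob, sampled_idxs) on the reproduction above would report the number of offending draws instead of a single minimum.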
The NumPy counterpart, np.random.choice, behaves correctly:
import numpy as np

# Same distribution as above, stored as float32.
sample_prob = np.zeros((7000000,)).astype(np.float32)
sample_prob[:5000000] = 1.0
sample_prob = sample_prob / (sample_prob.sum())
print(sample_prob.max())  # Output: 2e-07
print(sample_prob.min())  # Output: 0.0
sampled_idxs = np.random.choice(
    a=np.arange(len(sample_prob)),
    size=(len(sample_prob),),
    p=sample_prob,
    replace=True,
)
print((sample_prob[sampled_idxs]).min())  # Output: 2e-07, expected
This looks like unexpected behavior, possibly a bug?
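Until the cause is understood, one possible workaround is to sample only among the entries with non-zero probability and map the result back to the original indices. By construction this cannot return a zero-probability index, whatever the sampler does internally. The sketch below uses NumPy and a smaller array so it runs anywhere. The helper name choice_on_support is hypothetical. Applying the same remapping to jax.random.choice (passing the support indices as a and p[support] as p) is an assumption that has not been tested here.

```python
import numpy as np


def choice_on_support(rng, p, size):
    """Sample indices according to p, drawing only from entries where p > 0."""
    p = np.asarray(p, dtype=np.float64)
    support = np.flatnonzero(p > 0)              # indices allowed to appear
    p_support = p[support] / p[support].sum()    # renormalize over the support
    picks = rng.choice(len(support), size=size, replace=True, p=p_support)
    return support[picks]                        # map back to original indices


# Scaled-down version of the distribution in this report (illustrative only).
rng = np.random.default_rng(0)
p = np.zeros(7000)
p[:5000] = 1.0 / 5000
idxs = choice_on_support(rng, p, size=7000)
print(p[idxs].min() > 0)  # True: every drawn index has non-zero probability
```

The trade-off is an extra pass to build the support and a gather at the end, so this suits cases where the set of zero entries is known before sampling.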
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.38
jaxlib: 0.4.38
numpy: 2.1.3
python: 3.11.8 (stable, redacted, redacted) [Clang 9999.0.0 (be2df95e9281985b61270bb6420ea0eeeffbbe59)]
device info: Tesla V100-SXM2-16GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.10.0-smp-1106.20.0.0', version='#1 [v5.10.0-1106.20.0.0] SMP @1728697352', machine='x86_64')