jax.random.choice(replace=True) samples 0 probability index #25498

Open
@LemonATsu

Description

jax.random.choice(replace=True) samples entries that have zero probability when the input array is large and the average probability is low (~1e-07):

import jax
import jax.numpy as jnp
import numpy as np

sample_prob = np.zeros((7000000,))
sample_prob[:5000000] = 1.0
sample_prob = jnp.array(sample_prob / (sample_prob.sum()))
print(sample_prob.max())  # Output: 2e-07
print(sample_prob.min())  # Output: 0.0

sampled_idxs = jax.random.choice(
    jax.random.PRNGKey(0),
    a=jnp.arange(len(sample_prob)),
    shape=(len(sample_prob),),
    p=sample_prob,
    replace=True,
)

print((sample_prob[sampled_idxs]).min())  # Output: 0.0, shouldn't happen
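To quantify the problem beyond a single `min()`, one could count how many sampled indices land on zero-probability entries. The sketch below uses a hypothetical helper, `count_zero_prob_samples` (not part of JAX or NumPy), and a tiny hand-written stand-in for the arrays above:

```python
import numpy as np

def count_zero_prob_samples(p, sampled_idxs):
    """Count sampled indices whose probability in `p` is exactly 0."""
    p = np.asarray(p)
    sampled_idxs = np.asarray(sampled_idxs)
    return int(np.count_nonzero(p[sampled_idxs] == 0))

# Tiny hand-written stand-in for sample_prob / sampled_idxs (illustrative only).
p_demo = np.array([0.5, 0.0, 0.5], dtype=np.float32)
idxs_demo = np.array([0, 1, 1, 2])
n_bad = count_zero_prob_samples(p_demo, idxs_demo)
print(f"{n_bad} of {idxs_demo.size} samples have zero probability")
```

Calling `count_zero_prob_samples(sample_prob, sampled_idxs)` on the arrays from the reproduction would show whether the bad samples are rare or widespread.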

The NumPy counterpart, np.random.choice, behaves correctly:

sample_prob = np.zeros((7000000,)).astype(np.float32)
sample_prob[:5000000] = 1.0
sample_prob = sample_prob / (sample_prob.sum())
print(sample_prob.max())  # Output: 2e-07
print(sample_prob.min())  # Output: 0.0

sampled_idxs = np.random.choice(
    a=np.arange(len(sample_prob)),
    size=(len(sample_prob),),
    p=sample_prob,
    replace=True,
)

print((sample_prob[sampled_idxs]).min())  # Output: 2e-07, expected
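For reference, here is a minimal sketch of inverse-CDF sampling that cannot return a zero-probability index. It is a generic illustration, not a claim about how either library implements `choice`; the function name `inverse_cdf_choice` is made up for this example. It assumes a float64 cumulative sum, normalization so the last CDF entry is exactly 1, and a right-sided search:

```python
import numpy as np

def inverse_cdf_choice(rng, p, size):
    """Sample indices of `p` with replacement via a float64 CDF (illustrative)."""
    p = np.asarray(p, dtype=np.float64)
    cdf = np.cumsum(p)
    cdf /= cdf[-1]                    # last entry becomes exactly 1.0
    u = rng.random(size)              # uniform draws in [0, 1)
    # side="right" returns the first i with cdf[i] > u, so an entry with
    # p[i] == 0 (where cdf[i] == cdf[i - 1]) can never be selected.
    return np.searchsorted(cdf, u, side="right")

rng = np.random.default_rng(0)
p_small = np.array([0.0, 0.25, 0.0, 0.75, 0.0], dtype=np.float32)
idxs_small = inverse_cdf_choice(rng, p_small, 10_000)
print(np.unique(idxs_small))          # only indices with nonzero probability
```

Under these assumptions, flat stretches of the CDF (leading, interior, or trailing zeros) are skipped regardless of array size.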

This looks like unexpected behavior, possibly a bug?
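Until this is resolved, one possible workaround is to sample only among the entries with nonzero probability and map back to the original indices. This is a sketch, not a verified fix: `choice_on_support` is a hypothetical helper, it is not jit-compatible as written (the size of the support depends on the data), and it does not address any loss of precision among the nonzero entries themselves.

```python
import jax
import jax.numpy as jnp

def choice_on_support(key, p, shape):
    """Hypothetical helper: sample indices of `p`, skipping entries with p == 0."""
    support = jnp.flatnonzero(p > 0)          # original indices with p > 0
    p_support = p[support]
    p_support = p_support / p_support.sum()   # renormalize over the support
    # The returned values are elements of `support`, so each one has p > 0.
    return jax.random.choice(key, support, shape=shape, p=p_support, replace=True)

p_small = jnp.array([0.0, 0.25, 0.0, 0.75, 0.0])
idxs_small = choice_on_support(jax.random.PRNGKey(0), p_small, (1000,))
print(bool((p_small[idxs_small] > 0).all()))
```

Because every returned value is drawn from `support`, zero-probability indices are excluded by construction rather than by numerical accuracy.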

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.1.3
python: 3.11.8 (stable, redacted, redacted) [Clang 9999.0.0 (be2df95e9281985b61270bb6420ea0eeeffbbe59)]
device info: Tesla V100-SXM2-16GB-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.10.0-smp-1106.20.0.0', version='#1 [v5.10.0-1106.20.0.0] SMP @1728697352', machine='x86_64')

Labels

bug: Something isn't working
duplicate: This issue or pull request already exists
