
jnp.finfo(x).eps is not hashable #25571

Closed
@gaspardbb

Description


What is the best way to access the machine epsilon?
I want to compute a default value for a regularization parameter, which requires the machine epsilon of the input's dtype. For example:

import jax
import jax.numpy as jnp


def f(x):
    # Compute machine epsilon
    reg = jnp.finfo(x).eps * 10
    return x + reg


jax.jit(f)(jnp.ones(1))

But under jax.jit this emits the following warning:

jax/_src/core.py:700: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
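The warning names the tracer type, which indicates that the traced array `x` itself is being hashed when it is passed to `jnp.finfo`. A minimal sketch of one way around this, assuming only that `x.dtype` is available during tracing (it is a concrete NumPy dtype, not a tracer): pass the dtype rather than the array.

```python
import jax
import jax.numpy as jnp


def f(x):
    # x is a tracer under jit, but x.dtype is a static NumPy dtype,
    # so jnp.finfo receives a hashable argument instead of the tracer.
    reg = jnp.finfo(x.dtype).eps * 10
    return x + reg


out = jax.jit(f)(jnp.ones(1))
print(out.dtype, out)
```

Because the dtype is fixed at trace time, `reg` is effectively a constant for each input dtype the function is compiled for.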

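As a workaround, I ended up hardcoding a value per dtype:

import jax.numpy as jnp

def _machine_epsilon(x):
    # Rough per-dtype values, several times the true epsilon
    # (float32: ~1.19e-7, float64: ~2.22e-16).
    if x.dtype == jnp.float32:
        return 1e-6
    elif x.dtype == jnp.float64:
        return 1e-15
    raise TypeError(f"unsupported dtype: {x.dtype}")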
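The hardcoded workaround covers only float32 and float64. A dtype-generic alternative is sketched below, assuming the same `jnp.finfo(x.dtype)` approach; the helper name `machine_epsilon` and its `scale` parameter are hypothetical, chosen to mirror the `* 10` factor in the first snippet.

```python
import jax
import jax.numpy as jnp


def machine_epsilon(x, scale=10):
    """Hypothetical helper: `scale` times the epsilon of x's floating dtype."""
    # x.dtype is static under jit, so this check runs in Python at trace time.
    if not jnp.issubdtype(x.dtype, jnp.floating):
        raise TypeError(f"expected a floating dtype, got {x.dtype}")
    # Return a Python float: JAX treats it as weakly typed, so `x + eps`
    # keeps x's dtype (e.g. bfloat16 stays bfloat16).
    return scale * float(jnp.finfo(x.dtype).eps)


@jax.jit
def f(x):
    return x + machine_epsilon(x)


for dtype in (jnp.float16, jnp.bfloat16, jnp.float32):
    x = jnp.ones(1, dtype=dtype)
    print(x.dtype, machine_epsilon(x, scale=1), f(x).dtype)
```

With `scale=1` this gives 2**-10 for float16, 2**-7 for bfloat16 and 2**-23 for float32; the `scale` factor plays the role of the safety margin that the hardcoded values build in.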

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35
jaxlib: 0.4.35
numpy:  1.26.4
python: 3.11.10 (main, Oct 16 2024, 08:56:36) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='ariane', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:05:14 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T8103', machine='arm64')

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
