Closed
Description
What is the best way to access the machine epsilon?
I want to compute a default value for a regularization parameter, so I need the machine epsilon for the dtype of the input. Typically:
import jax
import jax.numpy as jnp
def f(x):
    # Compute machine epsilon
    reg = jnp.finfo(x).eps * 10
    return x + reg
jax.jit(f)(jnp.ones(1))
But this emits the following warning:
jax/_src/core.py:700: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
I end up doing
def _machine_epsilon(x):
    if x.dtype == jnp.float32:
        return 1e-6
    elif x.dtype == jnp.float64:
        return 1e-15
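For comparison with the hard-coded values above, here is a minimal sketch of one possible alternative: passing `x.dtype` to `jnp.finfo` instead of `x` itself. It assumes `x` has a floating-point dtype (`finfo` rejects integer dtypes), and the function name `add_regularization` is hypothetical. The dtype of a traced array is static metadata, so it should be available inside `jax.jit` without hashing the tracer. Note that the actual epsilons are about 1.19e-7 for float32 and 2.22e-16 for float64.

```python
import jax
import jax.numpy as jnp


def add_regularization(x):
    # x.dtype is static information, so it can be read even when x is a tracer.
    # Assumption: x has a floating-point dtype.
    eps = float(jnp.finfo(x.dtype).eps)
    # A Python float is weakly typed in JAX, so adding it keeps x's dtype.
    reg = 10 * eps
    return x + reg


out = jax.jit(add_regularization)(jnp.ones(1))
print(out.dtype, out)
```

Converting the epsilon to a Python float is a design choice in this sketch: it avoids accidentally promoting a float32 input if the NumPy scalar returned by `finfo` were to be widened by the multiplication.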
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.35
numpy: 1.26.4
python: 3.11.10 (main, Oct 16 2024, 08:56:36) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='ariane', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:05:14 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T8103', machine='arm64')