silent truncation of numpy int64 arrays to int32 range #18385
Description
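```python
import jax
import numpy

# numpy.array([2**32]) has dtype int64; with JAX's default
# jax_enable_x64=False it is converted to int32 on the way in.
y = jax.jit(lambda x: x + 0)(numpy.array([2**32]))
print(y)  # Array([0])
```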
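The result of 0 appears to match what a plain NumPy cast from int64 to int32 does: the value is reduced modulo 2**32 with no error or warning. The sketch below uses only NumPy to illustrate that wraparound; it is not JAX's internal conversion code. A known workaround on the JAX side is to enable 64-bit types at startup with `jax.config.update("jax_enable_x64", True)`, although that changes default dtypes globally.

```python
import numpy as np

# An int64 value that does not fit in int32 (int32 max is 2**31 - 1).
x = np.array([2**32], dtype=np.int64)

# astype() wraps integer overflow modulo 2**32 and raises nothing,
# so only the low 32 bits (all zero here) survive.
y = x.astype(np.int32)
print(y)  # [0]
```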
Consider fixing! Silently wrapping 2**32 to 0 produces a wrong result with no error or warning.
#15275 was an attempt at a fix, but it has bitrotted.
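Until the conversion itself checks ranges, a caller can guard inputs before handing them to JAX. The helper below, `to_int32_checked`, is hypothetical (it is not part of JAX or NumPy): a minimal sketch, assuming the caller wants an error rather than wraparound whenever an integer value falls outside the int32 range.

```python
import numpy as np


def to_int32_checked(x):
    """Hypothetical helper: convert integers to int32, raising on overflow
    instead of silently wrapping like a plain astype() would."""
    arr = np.asarray(x)
    if not np.issubdtype(arr.dtype, np.integer):
        raise TypeError(f"expected an integer array, got {arr.dtype}")
    info = np.iinfo(np.int32)
    # Compare as Python ints so the check is exact for int64 and uint64.
    if arr.size and (int(arr.min()) < info.min or int(arr.max()) > info.max):
        raise OverflowError(
            f"values in [{int(arr.min())}, {int(arr.max())}] "
            f"do not fit in int32 [{info.min}, {info.max}]"
        )
    return arr.astype(np.int32)
```

With x64 disabled, calling `jax.jit(f)(to_int32_checked(x))` would make an input like `numpy.array([2**32])` fail loudly with `OverflowError` instead of turning into 0.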
What jax/jaxlib version are you using?
No response
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response