
silent truncation of numpy int64 arrays to int32 range #18385

Open
@mattjj

Description

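```python
import jax
import numpy

# numpy.array([2**32]) has dtype int64; with 64-bit mode disabled (the default),
# JAX converts it to int32 without checking whether the values fit.
y = jax.jit(lambda x: x + 0)(numpy.array([2**32]))
print(y)  # [0]
```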
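The `0` comes from the int64 to int32 conversion wrapping modulo 2**32. The sketch below reproduces the arithmetic with NumPy alone, under the assumption that the conversion behaves like a plain `astype(np.int32)` cast, which keeps only the low 32 bits of each value. It does not call JAX.

```python
import numpy as np

# Values at and beyond the edge of the int32 range [-2**31, 2**31 - 1].
x = np.array([2**31 - 1, 2**31, 2**32, 2**32 + 5], dtype=np.int64)

# An int64 -> int32 astype keeps only the low 32 bits: no error, no warning.
# (Assumption: this mirrors the silent conversion seen in the jit repro.)
y = x.astype(np.int32)

# The same result computed explicitly as two's-complement wraparound mod 2**32.
wrapped = [((v + 2**31) % 2**32) - 2**31 for v in x.tolist()]

print(y.tolist())   # [2147483647, -2147483648, 0, 5]
print(wrapped)      # [2147483647, -2147483648, 0, 5]
```

So 2**32 maps to 0 and 2**31 maps to -2**31, with nothing to signal that the input was out of range.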

Consider fixing, e.g. by raising an error instead of silently truncating out-of-range values.

#15275 was an attempt, but it bitrotted.
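A minimal sketch of the kind of range check a fix could perform before narrowing an int64 array to int32. `checked_int32_cast` is a hypothetical helper written for illustration; it is not part of JAX's API and is not taken from #15275. As an existing workaround, enabling 64-bit mode with `jax.config.update("jax_enable_x64", True)` keeps such inputs as int64 instead of narrowing them.

```python
import numpy as np

def checked_int32_cast(x) -> np.ndarray:
    """Hypothetical helper: cast an integer array to int32, raising on overflow."""
    x = np.asarray(x)
    info = np.iinfo(np.int32)
    # Reject any value outside [-2**31, 2**31 - 1] instead of letting it wrap.
    if x.size and (x.min() < info.min or x.max() > info.max):
        raise OverflowError(
            f"value out of int32 range [{info.min}, {info.max}]; "
            "enable 64-bit mode or cast explicitly"
        )
    return x.astype(np.int32)

print(checked_int32_cast(np.array([1, -2, 2**31 - 1])).tolist())
try:
    checked_int32_cast(np.array([2**32]))
except OverflowError as e:
    print("OverflowError:", e)
```

Raising (rather than warning) matches the `better_errors` label on this issue, though a warning would be a less disruptive alternative for existing code that relies on the wraparound.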

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

Metadata

Assignees

Labels

P2 (eventual): This ought to be addressed, but has no schedule at the moment. (Assignee optional)
better_errors: Improve the error reporting

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
