diff --git a/CHANGELOG.md b/CHANGELOG.md index 71fe5ac06556..352c1d0e44e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Changes: + * The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum + supported version until June 2025. + ## jax 0.4.38 (Dec 17, 2024) * Changes: diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 989a8314eb92..4370aa3176aa 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -63,7 +63,7 @@ def has_ext_modules(self): install_requires=[ 'scipy>=1.10', "scipy>=1.11.1; python_version>='3.12'", - 'numpy>=1.24', + 'numpy>=1.25', 'ml_dtypes>=0.2.0', ], url='https://github.com/jax-ml/jax', diff --git a/setup.py b/setup.py index bb0c7841205f..07726d1b67f6 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ def load_version_module(pkg_path): install_requires=[ f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}', 'ml_dtypes>=0.4.0', - 'numpy>=1.24', + 'numpy>=1.25', "numpy>=1.26.0; python_version>='3.12'", 'opt_einsum', 'scipy>=1.10', diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index 02f5ad527c61..cce40788d94c 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -25,8 +25,6 @@ import numpy as np -numpy_version = jtu.numpy_version() - config.parse_flags_with_absl() try: @@ -48,10 +46,6 @@ [dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16], key=lambda x: x.__name__) -# NumPy didn't support bool as a dlpack type until 1.25. -if jtu.numpy_version() < (1, 25, 0): - numpy_dtypes = [dt for dt in numpy_dtypes if dt != jnp.bool_] - cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16] nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]