Skip to content

pytest: use importlib mode by default #28650

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025

Conversation

olupton
Copy link
Contributor

@olupton olupton commented May 9, 2025

Otherwise the default (prepend) mode will add <jax_src_dir>/tests and <jax_src_dir> to the start of sys.path.
This is (has always been) fragile, because <jax_src_dir> has jax/ and jaxlib/ subdirectories, but it recently broke as setuptools 80 has a change in behaviour for editable installations, with the result that if jaxlib is installed editable then import jaxlib with <jax_src_dir> in sys.path will try to import from <jax_src_dir>/jaxlib instead of the editable install location, and fail.

See also:

Otherwise the default (prepend) mode will add `<jax_src_dir>/tests` and
`<jax_src_dir>` to the start of `sys.path`. This is (has always been)
fragile, because `<jax_src_dir>` has `jax/` and `jaxlib/`
subdirectories, but it recently broke as `setuptools` 80 has a change in
behaviour for editable installations, with the result that if `jaxlib`
is installed editable then `import jaxlib` with `<jax_src_dir>` in
`sys.path` will try to import from `<jax_src_dir>/jaxlib` instead of the
editable install location, and fail.
@jakevdp jakevdp self-assigned this May 9, 2025
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! I didn't know about --import _mode before, but it looks like importlib is definitely the cleanest approach.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 9, 2025
@copybara-service copybara-service bot merged commit a4edfac into jax-ml:main May 9, 2025
21 checks passed
@jakevdp
Copy link
Collaborator

jakevdp commented May 12, 2025

Hi - it looks like this broke our nightly tests (see #28671) so I think I'm going to roll back this change.

@olupton
Copy link
Contributor Author

olupton commented May 12, 2025

Hi - it looks like this broke our nightly tests (see #28671) so I think I'm going to roll back this change.

I guess this is because of

https://docs.pytest.org/en/stable/explanation/pythonpath.html#import-modes

Disadvantages:

Test modules can’t import each other.
Testing utility modules in the tests directories (for example a tests.helpers module containing test-related functions/classes) are not importable. The recommendation in this case it to place testing utility modules together with the application/library code, for example app.testing.helpers.

and the relevant part is not included in the wheel installation?
With an editable installation of jax it worked.

@jakevdp
Copy link
Collaborator

jakevdp commented May 12, 2025

The error arises due to missing modules, but they're not test-specific modules. I'm seeing the following three files:

tests/pallas/export_back_compat_pallas_test.py:28: in <module>
    from jax._src.internal_test_util import export_back_compat_test_util as bctu
E   ModuleNotFoundError: No module named 'jax._src.internal_test_util'

tests/pallas/tpu_pallas_random_test.py:24: in <module>
    from jax.experimental.pallas.ops.tpu.random import philox  # pylint: disable=unused-import  # noqa: F401
E   ModuleNotFoundError: No module named 'jax.experimental.pallas.ops.tpu.random'

tests/mosaic/matmul_test.py:33: in <module>
    from jax.experimental.mosaic.gpu.examples import matmul
E   ModuleNotFoundError: No module named 'jax.experimental.mosaic.gpu.examples'

I know we had some changes to our wheel build process recently – perhaps these files are missing from the package distribution somehow?

@jakevdp
Copy link
Collaborator

jakevdp commented May 12, 2025

Ah, they're explicitly excluded from the package distribution: https://github.com/jax-ml/jax/blame/76e5bc6f5d5c89c9d56d3211dba1953ae8d20cde/setup.py#L60

This was done deliberately in 8fbe3b1

I'm still not sure why jax.experimental.pallas and jax.experimental.mosaic fail to import.

@olupton
Copy link
Contributor Author

olupton commented May 12, 2025

Ah, they're explicitly excluded from the package distribution: https://github.com/jax-ml/jax/blame/76e5bc6f5d5c89c9d56d3211dba1953ae8d20cde/setup.py#L60

This was done deliberately in 8fbe3b1

I'm still not sure why jax.experimental.pallas and jax.experimental.mosaic fail to import.

The *examples* pattern I assume matches jax.experimental.mosaic.gpu.examples. For the other one, is https://github.com/jax-ml/jax/tree/main/jax/experimental/pallas/ops/tpu/random missing an __init__.py?

copybara-service bot pushed a commit that referenced this pull request May 15, 2025
Re-land #28650, fixing build failures, I hope!

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 15, 2025
Re-land #28650, fixing build failures, I hope!

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 15, 2025
Re-land #28650, fixing build failures, I hope!

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 15, 2025
Re-land #28650, fixing build failures, I hope!

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 16, 2025
Re-land #28650, fixing build failures, I hope!

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 16, 2025
This is an attempt to re-land #28650, fixing build failures.

**Motivation**

#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do.

The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX.

**Solutions**

The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest.

This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me.

**Alternatives**

One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in #28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution!

A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter.

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 22, 2025
This is an attempt to re-land #28650, fixing build failures.

**Motivation**

#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do.

The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX.

**Solutions**

The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest.

This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me.

**Alternatives**

One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in #28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution!

A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter.

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 22, 2025
This is an attempt to re-land #28650, fixing build failures.

**Motivation**

#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do.

The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX.

**Solutions**

The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest.

This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me.

**Alternatives**

One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in #28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution!

A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter.

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 23, 2025
This is an attempt to re-land #28650, fixing build failures.

**Motivation**

#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do.

The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX.

**Solutions**

The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest.

This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me.

**Alternatives**

One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in #28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution!

A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter.

PiperOrigin-RevId: 759244911
copybara-service bot pushed a commit that referenced this pull request May 23, 2025
This is an attempt to re-land #28650, fixing build failures.

**Motivation**

#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do.

The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX.

**Solutions**

The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest.

This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me.

**Alternatives**

One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in #28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution!

A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter.

PiperOrigin-RevId: 762419160
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants