
When caching is enabled, also enable XLA caching features as well #22899

Open · wants to merge 1 commit into base: main

Conversation

@trevor-m commented Aug 6, 2024

This PR makes it easier to enable all of the caching features in JAX and XLA with a single option. Now, when the JAX persistent cache is enabled (`JAX_COMPILATION_CACHE_DIR`), some XLA caching features will also be enabled, writing to subdirectories of the JAX cache directory. The set of XLA caching features to use can be selected via `JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES`.

Currently, there is an issue related to kernel naming when both `xla_gpu_kernel_cache_file` and the JAX persistent cache are enabled together, so only the autotune cache is enabled by default for now. Once this is fixed, the default value of `JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES` should become `all`.
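
As a rough sketch (not part of this PR's diff), assuming the Python config names mirror the environment variables above and that the autotune-only value is spelled `xla_gpu_per_fusion_autotune_cache_dir`, usage might look like:

```python
import jax

# Enable the JAX persistent compilation cache; equivalent to setting
# the JAX_COMPILATION_CACHE_DIR environment variable. The path is
# illustrative.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

# Opt in to XLA caching features as well. The value below (autotune
# cache only, matching the current default described above) is an
# assumption about the option's spelling.
jax.config.update("jax_persistent_cache_enable_xla_caches",
                  "xla_gpu_per_fusion_autotune_cache_dir")
```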

Requires openxla/xla#15636
Requires openxla/xla#18450

@trevor-m (Author) commented Aug 6, 2024

cc @nouiz

Review thread on docs/persistent_compilation_cache.md (outdated):

* `none`: don't enable any extra XLA caching features

* `xla_gpu_kernel_cache_file`: only enable the kernel cache

Collaborator:

What is the issue with that one? Should we document it here?
Is it just that it doesn't work, or does it give hash collisions, or crash?

@trevor-m (Author) replied:

There is a crash which looks like this:

```
2024-07-30 18:04:57.120490: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:272] getting function input_concatenate_fusion from module 0x55d95b91f8d0
E0730 18:04:57.120755   52461 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: NOT_FOUND: Failed to get module function: CUDA_ERROR_NOT_FOUND: named symbol not found
```

@sergachev is taking a look at it and found it can be reproduced with `bazel test --test_env=XLA_FLAGS="--xla_gpu_enable_llvm_module_compilation_parallelism --xla_gpu_kernel_cache_file=/dev/shm/xla.kernel.cache" tests/compilation_cache_test_gpu`.

Collaborator:

Note, that issue is fixed in XLA: openxla/xla#15998.
But I think it is better not to enable all XLA caches at the same time, for better testing.
We can have a follow-up PR to expand it after the next JAX release.

@nouiz (Collaborator) commented Aug 20, 2024

The required XLA PR is merged: openxla/xla#15636
@hawkinsp can you review this PR?

@nouiz requested a review from @hawkinsp on August 20, 2024
@hawkinsp (Collaborator) left a comment:

The change is fine, but there are CI failures (possibly stale).

@trevor-m (Author) replied:

> The change is fine, but there are CI failures (possibly stale).

@hawkinsp Thanks for reviewing! I've rebased, which should fix the CI failures.

@hawkinsp (Collaborator):

One more thing: please squash your commits.

@mattjj added the `pull ready` label (Ready for copybara import and testing) on Oct 16, 2024
@dfm (Collaborator) commented Oct 16, 2024

@trevor-m — Thanks for your patience here! Can you rebase your PR onto the current main branch? We'll get this in ASAP after that. Thanks!

@trevor-m (Author) replied:

@dfm Thanks for looking at this. However, we may need to hold off on merging this a bit longer. We think there will be issues when using this feature with multi-host setups. To solve it, we can set `xla_gpu_experimental_autotune_cache_mode` to `update` for rank 0 only and to `read` for the other ranks. We will need to expose that flag in the XLA Python bindings first.

We will need to do something similar for the kernel cache.
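
A minimal sketch of the strategy described above, assuming the mode is chosen from the JAX process index (how the resulting value is plumbed into XLA's debug options is not shown and is an assumption):

```python
import jax

def autotune_cache_mode() -> str:
    # Only one process should write to the shared autotune cache;
    # all other processes open it read-only to avoid write conflicts.
    return "update" if jax.process_index() == 0 else "read"

# The returned mode would then be passed to XLA through the
# xla_gpu_experimental_autotune_cache_mode debug option.
```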

@trevor-m (Author) commented:

@dfm I've opened openxla/xla#18450 to expose the cache mode, and updated this PR to set it to `update` for process 0 and `read` for the other processes. I confirmed this fixes the multi-host issue.

@dfm self-assigned this on Oct 17, 2024
Review threads on jax/_src/compiler.py and jax/_src/config.py (outdated, resolved).
Commits:
- Add unit test
- Fix typechecker
- Set caching mode depending on process id