forked from jax-ml/jax
CI: 05/22/25 upstream sync #432
Open
rocm-repo-management-api-2 wants to merge 1,605 commits into rocm-main from ci-upstream-sync-199_1
…mmutable inside `jax.Array` is immutable and `ShapeDtypeStruct` is a duck of `jax.Array` but immutability was never enforced. **If you are broken by this change, just update your code to use sds.update(...)** PiperOrigin-RevId: 756852248
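The "update instead of mutate" pattern the commit message recommends can be illustrated with a plain-Python analogy. This is only a sketch: `ShapeDtypeSketch` is a hypothetical stand-in built on a frozen dataclass, not `jax.ShapeDtypeStruct` itself, and its `update` method merely mimics the `sds.update(...)` call named above.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class ShapeDtypeSketch:
    """Hypothetical immutable shape/dtype record (not the JAX class)."""
    shape: tuple
    dtype: str

    def update(self, **kwargs):
        # Return a new instance with the given fields replaced;
        # the original object is left untouched.
        return replace(self, **kwargs)


sds = ShapeDtypeSketch(shape=(2, 3), dtype="float32")

# In-place mutation is rejected once immutability is enforced.
try:
    sds.shape = (4, 3)
    mutated = True
except FrozenInstanceError:
    mutated = False

# The supported route: build an updated copy.
new_sds = sds.update(shape=(4, 3))
```

Code that previously assigned to an attribute would, under this pattern, rebind the variable to the result of `update(...)` instead.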
These are currently thread-unsafe due to python/cpython#132817
PiperOrigin-RevId: 756874608
PiperOrigin-RevId: 756887335
PiperOrigin-RevId: 756887370
PiperOrigin-RevId: 756892123
…ent the unreduced rule. Currently that's only `add`. PiperOrigin-RevId: 756902404
PiperOrigin-RevId: 756914581
PiperOrigin-RevId: 756989842
* fix leaking of internal symbolic zeros in returned cotangents
* fix a bug around symbolic zero output tangents
http://github.com/openxla/xla/commit/80924f3d144737d14758d8a92b236d90c8ec8cb9. PiperOrigin-RevId: 757132575
PiperOrigin-RevId: 757170420
PiperOrigin-RevId: 757200726
http://github.com/openxla/xla/commit/633c9abd097a2cf20884d29da51cc53b6e7144b5. PiperOrigin-RevId: 757401878
PiperOrigin-RevId: 757608204
…in TMEMRef PiperOrigin-RevId: 757628207
For recursive config definitions Bazel requires use of a single token notation `--config=value`
This is necessary to ensure that all SMEM reads issued from a current WG have completed before we schedule the copy (that acts as an SMEM write)! PiperOrigin-RevId: 757647993
http://github.com/openxla/xla/commit/6ad6ae3dafa9868708e54de10e3aeafb081a71f2. PiperOrigin-RevId: 757728274
Our current allocation scheme on GPU is unsafe in presence of multiple threads that might take diverging control paths. We work around this problem using our favorite trick and simply forbid this! With this change, `run_scoped(..., collective_axes="wg")` means that the same allocation will be returned in all programs that only differ in the `wg` axis. What's more, this call is a user promise that the allocation is a collective that will be executed by all threads along that axis. Only executing it on a subset is undefined behavior and in our current Mosaic GPU implementation might lead to deadlocks due to barriers. Note that nothing changes for single-threaded kernels, where run_scoped is always allowed. PiperOrigin-RevId: 757734362
PiperOrigin-RevId: 757735827
PTX docs are a bit confusing because the type is called e4m3, but [its description](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) indicates that it is actually e4m3fn (no infs, limited NaNs). PiperOrigin-RevId: 757741649
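To make "no infs, limited NaNs" concrete, here is a small decoder for the 8-bit e4m3fn layout as it is commonly defined (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, with only the all-ones exponent-and-mantissa pattern reserved for NaN). The function name `decode_e4m3fn` is made up for this sketch and is not taken from PTX or JAX.

```python
import math


def decode_e4m3fn(byte: int) -> float:
    """Decode one e4m3fn-encoded byte into a Python float."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a value in [0, 255]")
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    # Only S.1111.111 encodes NaN; there is no infinity encoding.
    if exponent == 0xF and mantissa == 0x7:
        return math.nan
    if exponent == 0:
        # Subnormal: no implicit leading one, fixed exponent of 1 - bias.
        return sign * (mantissa / 8.0) * 2.0 ** (1 - 7)
    # Normal: implicit leading one, biased exponent.
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)


all_values = [decode_e4m3fn(b) for b in range(256)]
nan_count = sum(math.isnan(v) for v in all_values)
has_inf = any(math.isinf(v) for v in all_values)
max_finite = max(v for v in all_values if not math.isnan(v))
```

Because the top exponent is reused for ordinary values, the largest finite magnitude is 448 rather than the 240 an IEEE-style e4m3 with reserved infinities would give.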
This should be useful for kernels such as FlashAttention since row-wise reductions can be performed entirely without any communication with other threads. PiperOrigin-RevId: 757746207
PiperOrigin-RevId: 757755328
PiperOrigin-RevId: 757761831
…ntended platforms

Fixes: jax-ml#28594

Currently `lax.platform_dependent` allows specifying code that behaves differently when lowered on different platforms. However, this function operates in a confusing way: it creates a branch on the platform, but lowers all branches for the **current** lowering platforms. For example, in the following code:

```
lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu)
```

If we lower for CPU, we lower both `for_cpu` and `for_tpu` for CPU (!), but only the branch corresponding to `for_cpu` will actually run. This is a problem if, e.g., `for_tpu` does not have a lowering for CPU: we get an error during lowering. There should be no error, because that branch is not actually needed. We add a new test, `test_platform_dependent_with_primitive_with_lowering_error`, to demonstrate this.

The solution implemented here is Solution A from jax-ml#28594: we add a `branches_platform` param to the `cond` primitive, which is propagated by all transformations. This param is used only for the conditionals arising from `lax.platform_dependent`. During lowering we drop the branches corresponding to platforms that are not being lowered for.
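The difference between the old and new lowering behavior can be modeled in a few lines of plain Python. This is a toy model under stated assumptions, not JAX's lowering code: `lower_branch`, `lower_all_branches`, `lower_relevant_branches`, and the `BRANCHES` table are all hypothetical names invented for illustration.

```python
def lower_branch(name, supported_platforms, platform):
    """Pretend to lower one branch; fail if it has no lowering for `platform`."""
    if platform not in supported_platforms:
        raise NotImplementedError(f"{name} has no lowering for {platform}")
    return f"{name}@{platform}"


# Maps the platform keyword to (branch name, platforms the branch can lower on).
BRANCHES = {
    "cpu": ("for_cpu", {"cpu"}),
    "tpu": ("for_tpu", {"tpu"}),
}


def lower_all_branches(branches, lowering_platform):
    """Old behavior: every branch is lowered for the current platform."""
    return {
        key: lower_branch(name, supported, lowering_platform)
        for key, (name, supported) in branches.items()
    }


def lower_relevant_branches(branches, lowering_platform):
    """New behavior: branches for other platforms are dropped before lowering."""
    return {
        key: lower_branch(name, supported, lowering_platform)
        for key, (name, supported) in branches.items()
        if key == lowering_platform
    }


try:
    lower_all_branches(BRANCHES, "cpu")
    old_failed = False
except NotImplementedError:
    old_failed = True

new_result = lower_relevant_branches(BRANCHES, "cpu")
```

In this model the old path fails on `for_tpu` even though it would never run on CPU, while the pruned path lowers only `for_cpu`.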
PiperOrigin-RevId: 761534198
PiperOrigin-RevId: 761545482
PiperOrigin-RevId: 761552874
…sts depend on NVIDIA CUDA wheels hermetically. The flag is enabled by default. To disable the dependency, pass `add_pypi_cuda_wheel_deps=False` in the Bazel options. PiperOrigin-RevId: 761568590
…ypes Fix: jax-ml#28416 PiperOrigin-RevId: 761577943
PiperOrigin-RevId: 761578503
This will make it easier to track down unexpected path mismatches in the future. PiperOrigin-RevId: 761584888
…-12.9 PiperOrigin-RevId: 761587875
PiperOrigin-RevId: 761612390
…for unpacked types and native tiling on TPUv5 PiperOrigin-RevId: 761676578
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This refactor required moving the definitions of a few private utilities from pjit and pxla, because these files are part of the larger jax build target. PiperOrigin-RevId: 761689391
PiperOrigin-RevId: 761690584
…ive tiling PiperOrigin-RevId: 761692972
…full explicit mode PiperOrigin-RevId: 761708753
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. This required moving some internal utilities out of dispatch.py, which is part of the main JAX build rule. I chose api_util.py because they seem to fit there. PiperOrigin-RevId: 761722054
Next steps:
- non-tile aligned
- Clean up fn and utilize it for general changeTiling
PiperOrigin-RevId: 761731600
…fo__ guards after 0.6.1 release. PiperOrigin-RevId: 761737523
…normal` and other APIs implementing the `Initializer` protocol. Currently it takes `key, shape, dtype` and now we added an optional out_sharding parameter to it. PiperOrigin-RevId: 761742909
PiperOrigin-RevId: 761745409
PiperOrigin-RevId: 761758158
PiperOrigin-RevId: 761806312
Daily sync with upstream