CI: 05/22/25 upstream sync #432


Open
wants to merge 1,605 commits into base: rocm-main

Conversation

rocm-repo-management-api-2[bot]

Daily sync with upstream

emilyfertig and others added 30 commits May 9, 2025 11:49
PiperOrigin-RevId: 756850393
…mmutable inside `jax.Array` is immutable and `ShapeDtypeStruct` duck-types as `jax.Array`, but immutability was never enforced.

**If this change breaks your code, update it to use `sds.update(...)`.**

PiperOrigin-RevId: 756852248
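A minimal migration sketch for the note above, assuming `ShapeDtypeStruct.update(...)` accepts the same field names as the constructor (the exact keyword names are an assumption, not confirmed by the commit):

```python
import jax
import jax.numpy as jnp

# ShapeDtypeStruct is now treated as immutable, mirroring jax.Array.
sds = jax.ShapeDtypeStruct((4, 8), jnp.float32)

# Instead of mutating attributes in place, build an updated copy.
# (Keyword names below are assumed to match the constructor fields.)
new_sds = sds.update(shape=(8, 8))
print(new_sds.shape, new_sds.dtype)
```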
PiperOrigin-RevId: 756874608
…ent the unreduced rule. Currently that's only `add`.

PiperOrigin-RevId: 756902404
PiperOrigin-RevId: 756914581
PiperOrigin-RevId: 756989842
* fix leaking of internal symbolic zeros in returned cotangents
* fix a bug around symbolic zero output tangents
PiperOrigin-RevId: 757170420
For recursive config definitions, Bazel requires the single-token notation `--config=value`.

This is necessary to ensure that all SMEM reads issued from the current WG
have completed before we schedule the copy (which acts as an SMEM write)!

PiperOrigin-RevId: 757647993
Our current allocation scheme on GPU is unsafe in the presence of multiple threads
that might take diverging control paths. We work around this problem using our
favorite trick: we simply forbid it!

With this change, `run_scoped(..., collective_axes="wg")` means that the same
allocation will be returned in all programs that differ only in the `wg` axis.
What's more, this call is a user promise that the allocation is a collective that
will be executed by all threads along that axis. Executing it on only a subset is
undefined behavior and in our current Mosaic GPU implementation might lead to deadlocks
due to barriers.

Note that nothing changes for single-threaded kernels, where `run_scoped` is always
allowed.

PiperOrigin-RevId: 757734362
PiperOrigin-RevId: 757735827
PTX docs are a bit confusing because the type is called e4m3, but
[its description](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats)
indicates that it is actually e4m3fn (no infs, limited NaNs).

PiperOrigin-RevId: 757741649
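For reference, the `fn` ("finite") variant is what JAX exposes as `jnp.float8_e4m3fn`; a quick way to inspect its finite range (plain dtype introspection, unrelated to the PTX lowering itself):

```python
import jax.numpy as jnp

# float8_e4m3fn has no infinities and a limited NaN encoding; its finite range
# can be inspected via finfo.
info = jnp.finfo(jnp.float8_e4m3fn)
print(info.max)   # largest finite value
print(info.tiny)  # smallest positive normal value
```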
This should be useful for kernels such as FlashAttention since row-wise
reductions can be performed entirely without any communication with other threads.

PiperOrigin-RevId: 757746207
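To illustrate what "row-wise reductions" means here, a plain `jax.numpy` sketch of the per-row softmax statistics a FlashAttention-style kernel maintains (this only illustrates the math, not the Mosaic GPU layout added by this change):

```python
import jax.numpy as jnp

# Each row's max and sum are computed independently of every other row, so a
# layout that keeps a full row within one thread needs no cross-thread
# communication for these reductions.
scores = jnp.ones((8, 128), jnp.float32)           # per-row attention scores
row_max = jnp.max(scores, axis=-1, keepdims=True)  # running max per row
probs = jnp.exp(scores - row_max)
row_sum = jnp.sum(probs, axis=-1, keepdims=True)   # normalizer per row
out = probs / row_sum
print(out.shape)
```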
…ntended platforms

Fixes: jax-ml#28594

Currently `lax.platform_dependent` allows specifying code that behaves
differently when lowered for different platforms. However, this function
operates in a confusing way: it creates a branch on the platform, but
lowers all branches for the **current** lowering platforms.

For example, in the following code:
```python
lax.platform_dependent(x,
                       cpu=for_cpu, tpu=for_tpu)
```

If we lower for CPU, we lower both `for_cpu` and `for_tpu`
for CPU (!), but only the branch corresponding to `for_cpu`
will actually run.

This is a problem if, e.g., `for_tpu` does not have a lowering for CPU:
we will get an error during lowering, even though there should be no error
at all, because that branch is not actually needed.

We add a new test `test_platform_dependent_with_primitive_with_lowering_error`
to demonstrate this.

The solution implemented here is Solution A from jax-ml#28594: we
add a `branches_platform` param to the `cond` primitive, which is
propagated by all transformations. This param is used only for the
conditionals arising from `lax.platform_dependent`.
During lowering we drop the branches corresponding to the platforms
that are not relevant.
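A small usage sketch of the API in question (`for_cpu` and `for_tpu` are placeholder callables, not code from this change):

```python
import jax
import jax.numpy as jnp
from jax import lax

def for_cpu(x):
  return x * 2.0

def for_tpu(x):
  return x + 1.0

def f(x):
  # With this change, only the branches whose platforms can actually be the
  # lowering platform are lowered; the others are dropped during lowering.
  return lax.platform_dependent(x, cpu=for_cpu, tpu=for_tpu, default=for_cpu)

print(jax.jit(f)(jnp.float32(3.0)))
```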
dfm and others added 28 commits May 21, 2025 11:15
…sts depend on NVIDIA CUDA wheels hermetically.

The flag is enabled by default.

To disable the dependency, pass `add_pypi_cuda_wheel_deps=False` in the Bazel options.

PiperOrigin-RevId: 761568590
This will make it easier to track down unexpected path mismatches in the future.

PiperOrigin-RevId: 761584888
…for unpacked types and native tiling on TPUv5

PiperOrigin-RevId: 761676578
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

This refactor required moving the definitions of a few private utilities from pjit and pxla, because these files are part of the larger jax build target.

PiperOrigin-RevId: 761689391
…full explicit mode

PiperOrigin-RevId: 761708753
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

This required moving some internal utilities out of dispatch.py, which is part of the main JAX build rule. I chose api_util.py because they seem to fit there.

PiperOrigin-RevId: 761722054
Next steps:
  - non-tile aligned
  - Clean up fn and utilize it for general changeTiling

PiperOrigin-RevId: 761731600
…fo__ guards after 0.6.1 release.

PiperOrigin-RevId: 761737523
…normal` and other APIs implementing the `Initializer` protocol. Previously it took `key, shape, dtype`; now it also accepts an optional `out_sharding` parameter.

PiperOrigin-RevId: 761742909
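For context, a sketch of the `Initializer` call convention being extended (the new `out_sharding` argument is omitted here because it requires a mesh; its exact spelling follows the commit note above):

```python
import jax
import jax.numpy as jnp

# An Initializer is called as init(key, shape, dtype); out_sharding is the new,
# optional extra parameter and is not passed in this minimal example.
init = jax.nn.initializers.normal(stddev=0.02)
key = jax.random.key(0)
w = init(key, (128, 256), jnp.float32)
print(w.shape, w.dtype)
```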
PiperOrigin-RevId: 761745409
rocm-repo-management-api-2 bot requested a review from a team as a code owner May 22, 2025 06:02
rocm-repo-management-api-2 bot enabled auto-merge (rebase) May 22, 2025 06:02