
CI: 05/02/25 upstream sync #399

Open · wants to merge 1,191 commits into base: rocm-main

Conversation

rocm-repo-management-api-2[bot]

Daily sync with upstream

apaszke and others added 30 commits April 23, 2025 05:32
…matting

There are no restrictions on window sizes in that backend.

Also, replace Markdown quotations with Note/Warning blocks in the GPU reference
for added clarity.

PiperOrigin-RevId: 750555285
PiperOrigin-RevId: 750575644
PiperOrigin-RevId: 750600271
At the moment mypy isn't correctly detecting errors related to jaxlib. A future change will fix that, and this PR fixes the errors that will be revealed by it.

PiperOrigin-RevId: 750603531
Users should now be able to instantiate aliased `smem` buffers by using a
`RefUnion`, which takes a variadic number of trees of refs as input.
`RefUnion` represents a union/coproduct of all its operands: its operand
groups alias each other (overlap in memory), while the elements within each
group represent products and are laid out consecutively in memory.

The resulting aliased `smem` ref can then be unfolded into a flat structure
using assignment inside the kernel. Here is an example:

```
# Assumed imports for this example (the original snippet omits them).
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu
from jax.experimental.pallas import pallas_call

@functools.partial(
  pallas_call,
  out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
  in_specs=[pl.BlockSpec((256,))],
  out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM),
  scratch_shapes=[
    # The 256-element ref aliases (overlaps) the group of two 128-element refs.
    plgpu.RefUnion(
      plgpu.SMEM((256,), jnp.float32),
      [
        plgpu.SMEM((128,), jnp.float32),
        plgpu.SMEM((128,), jnp.float32),
      ],
    )
  ],
)
def kernel(x_ref, o_ref128, aliased_ref):
  # Unfold the aliased ref into its flat structure via unpacking assignment.
  smem_ref256, _, smem_ref128 = aliased_ref
  smem_ref256[...] = x_ref[...] + 1
  plgpu.commit_smem()
  plgpu.copy_smem_to_gmem(smem_ref128, o_ref128)
```

PiperOrigin-RevId: 750624152
…issue example is already varying on `x` and we were dropping that when we called into `lax_internal._one`

Fixes jax-ml#28193

PiperOrigin-RevId: 750634402
Having the directory structure of the jaxlib wheel differ from that of the source tree confuses type checkers such as mypy, since they sometimes find type stubs in the installed jaxlib wheel and sometimes in the installed source tree.

Instead:
* don't include type stubs in the jaxlib wheel
* don't install the jaxlib wheel as part of pre-commit
* make sure that the type stubs (and the underlying libraries) occupy the same position in the `jaxlib/` directory of the JAX source tree as they would in the installed jaxlib wheel.

For now, we leave some stubs that forward from the old locations to the new locations for certain headers and modules. These will be removed after migrating some users.

PiperOrigin-RevId: 750650528
PiperOrigin-RevId: 750652690
Renaming only, no functional changes intended.

There are two reasons to do this:
* I want to split some XLA-specific things out of the JAX wheel and move them back into the XLA repository. It would be nice if the name "xla" could be reserved for that extension instead.
* There are lots of jax-specific things in this extension.

PiperOrigin-RevId: 750709831
PiperOrigin-RevId: 750722897
PiperOrigin-RevId: 750732175
…es aren't removed in sharding propagation.

PiperOrigin-RevId: 750739094
…_p from `check_rep` to `check_vma`.

PiperOrigin-RevId: 750741689
… param->value dict

The parameters must be specified via a dataclass or a mapping from a backend to the
corresponding dataclass.

PiperOrigin-RevId: 750750391
…nd also change docs to point to `jax.shard_map`

PiperOrigin-RevId: 750760353
PiperOrigin-RevId: 750907599
…pgroup logic in the dialect lowering.

The `DialectBarrierRef` class has the same interface as `BarrierRef`, but uses mgpu ops for initialization and `expect_arrive_tx`. This makes the IR cleaner and also lets us adjust arrival counts and bytes in the dialect lowering, which in turn keeps the high-level code cleaner.

The new lowering always has all threads in a warpgroup arrive when using WG semantics. Previously only a single thread arrived, but keeping that behavior would have complicated things going forward.

The existing tests (including the one that's no longer skipped) cover the new behavior.

PiperOrigin-RevId: 750948900
Google-ML-Automation and others added 28 commits May 1, 2025 06:10
I don't expect that we actually need the device ordinal to be defined on the execution context, but we can add it back in statically (it's already decoded in the handler) if necessary.

PiperOrigin-RevId: 753573598
This is quite helpful while trying to debug the load/store routines.

PiperOrigin-RevId: 753599963
… Layout API rename!

PiperOrigin-RevId: 753636449
…t. So `1.0:f32` -> `1.0:f32[]`

PiperOrigin-RevId: 753640777
This issue occurs when some of the leaves have custom `__eq__` methods defined on them, which either raise errors when compared to certain other types (see http://cl/753579906) or return values whose truthiness cannot be evaluated, e.g.:

```
import jax.tree_util as jtu
import numpy as np

jtu.all_leaves(
  [[np.asarray([1, 2])]],
  is_leaf=lambda x: jtu.all_leaves([x]),
)
```

```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```

This fix avoids equality issues by using the `is` operator instead of `==`, and introduces tests for the case where `is_leaf` is provided.
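
For illustration, a minimal sketch (not the actual patch) of why identity comparison sidesteps the problem for a NumPy leaf:

```
import numpy as np

a = np.asarray([1, 2])

# `==` on NumPy arrays is elementwise, so using its result in a boolean
# context raises the ValueError shown above:
#     if a == a: ...
# `is` compares object identity and always yields a plain bool:
assert a is a
```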

PiperOrigin-RevId: 753684035
Useful for filtering events by function name or differentiating between events.

PiperOrigin-RevId: 753695215
… in_axes, out_axes, axis_name)`. This change does NOT make the API public.

The API semantics are as follows:
* `smap` only allows going into `Manual` mode one mesh axis at a time via the `axis_name` argument.

* The mesh needs to be present in the context via `use_mesh` or `set_mesh`.

* If `in_axes` or `out_axes` contains `None`, the corresponding input or output is **replicated**. This is similar to `vmap`, where `None` means an unmapped input.

* If the context mesh is in full explicit mode, `in_axes` can be inferred from the arguments. But how do we tell `smap` to do that? We **can't** use `None`, because `None` means replicated in `smap`.
So we introduce a singleton called `Infer` which, when passed to `smap`, tells it to infer the in_axes (in_specs) from the arguments. For example: `smap(f, in_axes=Infer, out_axes=0, axis_name='x')`.
You always have the option of specifying `in_axes` rather than inferring them, even in full explicit mode :) (See the sketch after this list.)
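
Below is a minimal, hypothetical usage sketch of these semantics. `smap` and `Infer` are not public, so the import location shown is an assumption, as is the exact mesh setup; only the call signature and the meaning of `in_axes`/`out_axes`/`axis_name` come from the description above.

```
import jax
import jax.numpy as jnp

# Hypothetical import location; the note above does not say where these live.
from jax.experimental.shard_map import smap, Infer

mesh = jax.make_mesh((2,), ('x',))

def f(block):
  # Runs in Manual mode over mesh axis 'x'; `block` is the per-device shard.
  return block * 2

with jax.sharding.use_mesh(mesh):
  # in_axes=0 maps axis 0 of the argument over 'x'; None would mean replicated.
  out = smap(f, in_axes=0, out_axes=0, axis_name='x')(jnp.arange(8.0))
  # In full explicit mode, in_axes=Infer would infer the specs from the
  # arguments instead of them being given explicitly.
```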

PiperOrigin-RevId: 753695446
Updates LLVM usage to match
[7752e0a10b25](llvm/llvm-project@7752e0a10b25)

PiperOrigin-RevId: 753710403
Extracting `.tar` files is 10 times faster than extracting `.tar.xz` files. By enabling `.tar` file usage in RBE jobs, we will save at least one minute of execution time in all Bazel RBE GPU jobs.

PiperOrigin-RevId: 753730448
PiperOrigin-RevId: 753737826
The XLA GPU runtime does not yet handle device assertions well and will hang
if the assert is triggered. However, the assertion output still appears in
stderr, so I think having `cf.assert` support is still useful.

PiperOrigin-RevId: 753742121
PiperOrigin-RevId: 753786660
… just forward out_sharding to their lax variants.

PiperOrigin-RevId: 753797017
PiperOrigin-RevId: 753797982
rocm-repo-management-api-2[bot] requested a review from a team as a code owner May 2, 2025 06:02
rocm-repo-management-api-2[bot] enabled auto-merge (rebase) May 2, 2025 06:02