forked from jax-ml/jax
CI: 05/02/25 upstream sync #399
Open
rocm-repo-management-api-2 wants to merge 1,191 commits into rocm-main from ci-upstream-sync-185_1
…matting There are no restrictions on window sizes in that backend. Also, replace Markdown quotations with Note/Warning blocks in the GPU reference for added clarity. PiperOrigin-RevId: 750555285
http://github.com/openxla/xla/commit/99b7c3bf05c3877c70ad587439b7481889810564. PiperOrigin-RevId: 750569770
PiperOrigin-RevId: 750570499
PiperOrigin-RevId: 750574686
PiperOrigin-RevId: 750575644
PiperOrigin-RevId: 750600271
At the moment mypy isn't correctly detecting errors related to jaxlib. In a future change this will be fixed, and this PR fixes errors that will be revealed by that change. PiperOrigin-RevId: 750603531
PiperOrigin-RevId: 750619702
Users should now be able to instantiate aliased `smem` buffers by using a `RefUnion`, which takes a variadic number of trees of refs as input. `RefUnion` represents a union/coproduct of all its operands: its operand groups alias (overlap in memory), while the elements within each group represent products and are laid out consecutively in memory. The resulting aliased `smem` ref can then be unfolded into a flat structure using assignment inside the kernel.

Here is an example:

```
@functools.partial(
    pallas_call,
    out_shape=jax.ShapeDtypeStruct([128], jnp.float32),
    in_specs=[pl.BlockSpec((256,))],
    out_specs=pl.BlockSpec((128,), memory_space=plgpu.GMEM),
    scratch_shapes=[
        plgpu.RefUnion(
            plgpu.SMEM((256,), jnp.float32),
            [
                plgpu.SMEM((128,), jnp.float32),
                plgpu.SMEM((128,), jnp.float32),
            ],
        )
    ],
)
def kernel(x_ref, o_ref128, aliased_ref):
  smem_ref256, _, smem_ref128 = aliased_ref
  smem_ref256[...] = x_ref[...] + 1
  plgpu.commit_smem()
  plgpu.copy_smem_to_gmem(smem_ref128, o_ref128)
```

PiperOrigin-RevId: 750624152
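The aliasing layout described above can be pictured with plain NumPy views. This is only an analogy, not the Pallas API: the names `backing`, `ref256`, `ref128_a` and `ref128_b` are made up for illustration, and the sketch assumes the second union member is two 128-element buffers placed back to back over the same memory as the 256-element buffer.

```python
import numpy as np

# One backing buffer of 256 float32 values standing in for shared memory
# (hypothetical stand-in; real SMEM is allocated by the kernel launcher).
backing = np.zeros(256, dtype=np.float32)

# First union member: a single 256-element view over the whole buffer.
ref256 = backing[:]

# Second union member: a product of two 128-element views that sit
# consecutively in memory and overlap the 256-element view.
ref128_a = backing[:128]
ref128_b = backing[128:]

# Writing through one member is visible through the other, because the
# members of the union alias the same bytes.
ref256[...] = np.arange(256, dtype=np.float32) + 1

print(ref128_a[0], ref128_b[0])  # first element of each half
print(np.shares_memory(ref256, ref128_b))    # members of the union overlap
print(np.shares_memory(ref128_a, ref128_b))  # elements of a product do not
```

In the kernel from the commit message, the same idea lets a write to the 256-element ref be read back through the second 128-element ref before it is copied out.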
…issue example is already varying on `x`, and we were dropping that when we called into `lax_internal._one`.

Fixes jax-ml#28193

PiperOrigin-RevId: 750634402
PiperOrigin-RevId: 750637693
PiperOrigin-RevId: 750637751
Having the directory structure of the jaxlib wheel be different to the source tree confuses type checkers such as mypy, since sometimes they find type stubs in the installed jaxlib wheel, and sometimes from the installed source tree.

Instead:
* don't include type stubs in the jaxlib wheel
* don't install the jaxlib wheel as part of pre-commit
* make sure that the location of type stubs (and the underlying libraries) is in the same position in the `jaxlib/` directory of the JAX source tree as it would be for the jaxlib wheel when installed.

For now, we leave some stubs that forward from the old locations to the new locations for certain headers and modules. These will be removed after migrating some users.

PiperOrigin-RevId: 750650528
PiperOrigin-RevId: 750681923
Renaming only, no functional changes intended.

There are two reasons to do this:
* I want to split some XLA specific things out of the JAX wheel and move them back into the XLA repository. It would be nice if the name "xla" could be reserved for that extension instead.
* There are lots of jax-specific things in this extension.

PiperOrigin-RevId: 750709831
PiperOrigin-RevId: 750722897
PiperOrigin-RevId: 750725054
PiperOrigin-RevId: 750732175
…es aren't removed in sharding propagation. PiperOrigin-RevId: 750739094
…_p from `check_rep` to `check_vma`. PiperOrigin-RevId: 750741689
… param->value dict The parameters must be specified via a dataclass or a mapping from a backend to the corresponding dataclass. PiperOrigin-RevId: 750750391
…nd also change docs to point to `jax.shard_map` PiperOrigin-RevId: 750760353
…changelog PiperOrigin-RevId: 750900798
PiperOrigin-RevId: 750907599
…pgroup logic in the dialect lowering. The `DialectBarrierRef` class has the same interface as `BarrierRef`, but uses mgpu ops for initialization and `expect_arrive_tx`. This makes the IR cleaner and also allows us to take care of adjusting arrival counts and bytes in the dialect lowering. That makes the high-level code cleaner. The new lowering always has all threads in a warpgroup arrive when using WG semantics. The behavior so far was to have only a single thread arrive, but keeping this would have complicated things going forward. The existing tests (including the one that's no longer skipped) test the new behavior. PiperOrigin-RevId: 750948900
PiperOrigin-RevId: 750954967
http://github.com/openxla/xla/commit/15565b8da6d85e9faec669cb22878a0e44cca4ee. PiperOrigin-RevId: 753562330
I don't expect that we actually need the device ordinal to be defined on the execution context, but we can add it back in statically (it's already decoded in the handler) if necessary. PiperOrigin-RevId: 753573598
This is quite helpful while trying to debug the load/store routines. PiperOrigin-RevId: 753599963
… Layout API rename! PiperOrigin-RevId: 753636449
PiperOrigin-RevId: 753639002
…t. So `1.0:f32` -> `1.0:f32[]` PiperOrigin-RevId: 753640777
This issue occurs when some of the leaves have custom `__eq__` methods defined on them, which either result in errors when compared to some other types (see http://cl/753579906), or result in return values that cannot have their truthiness evaluated, e.g.:

```
import jax.tree_util as jtu
import numpy as np

jtu.all_leaves(
    [[np.asarray([1, 2])]],
    is_leaf=lambda x: jtu.all_leaves([x]),
)
```

```
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```

This fix avoids equality issues by using the `is` operator instead of `==`, and introduces tests for the case where `is_leaf` is provided.

PiperOrigin-RevId: 753684035
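The difference between an equality-based and an identity-based check can be reproduced outside of JAX. The sketch below is not the code from the fix; `contains_by_equality` and `contains_by_identity` are hypothetical helpers that only illustrate why `==` is fragile for NumPy arrays while `is` is not.

```python
import numpy as np


def contains_by_equality(items, target):
    # `in` on a list falls back to `==` when the objects are not identical.
    # For arrays, `==` returns an array, and taking its truth value raises.
    return target in items


def contains_by_identity(items, target):
    # `is` never calls `__eq__`, so custom equality cannot interfere.
    return any(item is target for item in items)


leaf = np.asarray([1, 2])
other = np.asarray([1, 2])  # equal contents, but a distinct object

try:
    contains_by_equality([leaf], other)
    equality_raised = False
except ValueError:
    equality_raised = True

print(equality_raised)
print(contains_by_identity([leaf], leaf))
print(contains_by_identity([leaf], other))
```

The identity check answers a narrower question (is this the very same object?), which is presumably what a leaf check needs, since it compares objects taken from the same tree.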
PiperOrigin-RevId: 753685498
PiperOrigin-RevId: 753689814
Useful for filtering events by function name or differentiating between events. PiperOrigin-RevId: 753695215
… in_axes, out_axes, axis_name)`.

This change does NOT make the API public.

The API semantics are as follows:
* `smap` only allows going into `Manual` mode one mesh axis at a time via the `axis_name` argument.
* The mesh needs to be present in the context via `use_mesh` or `set_mesh`.
* If `in_axes` or `out_axes` contains `None`, it means that the input(s) are **replicated**. This is similar to `vmap`, where `None` means unmapped input.
* If the context mesh is in full explicit mode, `in_axes` can be inferred from the arguments. But how do we tell `smap` to do that? We **can't** use `None` because `None` means replicated in `smap`. So we introduce a singleton called `Infer` which, when passed to `smap`, will tell it to infer the in_axes (in_specs) from the arguments! For example: `smap(f, in_axes=Infer, out_axes=0, axis_name='x')`.

You always have the option of specifying `in_axes` and not inferring, even in full explicit mode :)

PiperOrigin-RevId: 753695446
PiperOrigin-RevId: 753705149
PiperOrigin-RevId: 753707559
Updates LLVM usage to match [7752e0a10b25](llvm/llvm-project@7752e0a10b25) PiperOrigin-RevId: 753710403
The extraction of `.tar` files is 10 times faster than the extraction of `.tar.xz` files. By enabling the use of `.tar` files in RBE jobs, we are going to save at least one minute of execution time in all Bazel RBE GPU jobs. PiperOrigin-RevId: 753730448
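For anyone who wants to check the extraction cost on their own machine, here is a minimal sketch using only the Python standard library. It is not part of the CI change: the helper names, the synthetic payload and its size are all assumptions, and the numbers it prints will depend on the hardware and on how compressible the real archives are.

```python
import os
import tarfile
import tempfile
import time


def build_archive(archive_path, mode, source_dir):
    """Pack source_dir into archive_path; mode is "w" (plain) or "w:xz"."""
    with tarfile.open(archive_path, mode) as tar:
        tar.add(source_dir, arcname="payload")


def timed_extract(archive_path, dest_dir):
    """Extract the archive and return the elapsed wall-clock seconds."""
    start = time.perf_counter()
    with tarfile.open(archive_path, "r:*") as tar:
        tar.extractall(dest_dir)
    return time.perf_counter() - start


def compare(num_files=20, file_size=64 * 1024):
    """Return extraction timings for a synthetic payload in both formats."""
    timings = {}
    with tempfile.TemporaryDirectory() as root:
        source_dir = os.path.join(root, "src")
        os.makedirs(source_dir)
        for i in range(num_files):
            with open(os.path.join(source_dir, f"blob_{i}.bin"), "wb") as f:
                f.write(os.urandom(file_size))
        for label, mode in (("tar", "w"), ("tar.xz", "w:xz")):
            archive_path = os.path.join(root, f"payload.{label}")
            build_archive(archive_path, mode, source_dir)
            dest_dir = os.path.join(root, f"out_{label}")
            os.makedirs(dest_dir)
            timings[label] = timed_extract(archive_path, dest_dir)
    return timings


if __name__ == "__main__":
    for label, seconds in compare().items():
        print(f"{label}: {seconds:.4f}s")
```

Random bytes barely compress, so this payload is a poor model of real toolchain archives; swapping in a representative directory would give a more meaningful comparison.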
PiperOrigin-RevId: 753732387
PiperOrigin-RevId: 753734380
PiperOrigin-RevId: 753737826
The XLA GPU runtime does not yet handle device assertions well and will hang if the assert is triggered. However, the assertion output still appears in stderr, so I think having `cf.assert` support is still useful. PiperOrigin-RevId: 753742121
PiperOrigin-RevId: 753758114
PiperOrigin-RevId: 753786660
… just forward out_sharding to their lax variants. PiperOrigin-RevId: 753797017
PiperOrigin-RevId: 753797982
PiperOrigin-RevId: 753840368
PiperOrigin-RevId: 753859510
Daily sync with upstream