
[Pallas] [Mosaic GPU] Add GPU pipelining docs #28135

Open · wants to merge 1 commit into main

Conversation

justinjfu (Collaborator) commented on Apr 19, 2025:

Adds docs covering emit_pipeline and emit_pipeline_warp_specialized.

Remaining TODOs possibly include a flash attention example with ping-pong scheduling.


- `body` and `grid` have the same semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially.
- `in_specs` and `out_specs` also work the same as in `pl.pallas_call`, except they additionally accept `plgpu.GPUBlockSpec` instances that can be used to specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on the available transformations.
- `max_concurrent_steps` controls the maximum number of pipeline stages to use. Using additional stages will consume more SMEM to hold temporary buffers, so this option should be used carefully.

Collaborator:

Not sure if "pipeline stages" is quite right here. Should we explain this in terms of the number of copy operations running concurrently?

CC @apaszke

Collaborator Author:

I updated this to just say the number of concurrent memory copies in flight.

- `delay_release` allows the user to specify the number of steps to wait before re-using a buffer. This is useful for certain optimizations such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled.
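
To make the parameters above concrete, here is a minimal sketch of a pipelined copy kernel. The block shape, grid, and copy body are invented for illustration, and the exact `emit_pipeline` body signature should be double-checked against the Pallas GPU reference; treat this as a sketch rather than the documented example.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def copy_kernel(x_gmem, o_gmem):
  # The pipeline body sees one SMEM block per input/output spec.
  def copy_block(x_smem, o_smem):
    o_smem[...] = x_smem[...]

  block_spec = pl.BlockSpec(
      block_shape=(256, 256),
      index_map=lambda i, j: (i, j),
  )
  pipeline = plgpu.emit_pipeline(
      copy_block,
      grid=(2, 2),              # 4 sequential steps over a (512, 512) array
      in_specs=[block_spec],
      out_specs=[block_spec],
      max_concurrent_steps=2,   # double-buffer the SMEM staging buffers
  )
  # Calling the pipeline with the GMEM refs issues the async copies.
  pipeline(x_gmem, o_gmem)
```

The kernel would then be launched with its operands left in GMEM (for example via `plgpu.kernel`, or via the `pl.pallas_call` compatibility path described below), so that the pipeline rather than the caller manages the SMEM staging buffers.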


As an alternative to `emit_pipeline`, Mosaic GPU also implements the existing `pl.pallas_call` API for pipelining. Pipelining with `pl.pallas_call` directly requires the user to pass in a `plgpu.GPUCompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining:
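
For example, a sketch of this compatibility path could look like the following. The kernel and block shapes are made up, the `GPUCompilerParams` fields shown (`max_concurrent_steps`, `delay_release`) are the ones named in the surrounding text, and selecting the Mosaic GPU lowering via the `backend` argument is an assumption worth verifying against the current `pallas_call` API:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def add_one(x_smem, o_smem):
  o_smem[...] = x_smem[...] + 1.0


x = jnp.zeros((512, 128), jnp.float32)
y = pl.pallas_call(
    add_one,
    grid=(4,),  # pipelined sequentially over 4 row blocks
    in_specs=[pl.BlockSpec((128, 128), lambda i: (i, 0))],
    out_specs=pl.BlockSpec((128, 128), lambda i: (i, 0)),
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    backend="mosaic_gpu",  # assumed: select the Mosaic GPU lowering
    compiler_params=plgpu.GPUCompilerParams(
        max_concurrent_steps=2,  # number of memory copies kept in flight
        delay_release=1,         # wait one extra step before reusing a buffer
    ),
)(x)
```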

Collaborator:

I would personally just leave this out tbh and tell people to use plgpu.kernel, since it allows you to do everything pl.pallas_call can and more (if you need to).

Collaborator Author:

I'll mark this as a "Compatibility" API only and nudge people to use plgpu.kernel instead? I think some users might prefer having "one way to do things" and use pallas_call for both GPU/TPU.

@justinjfu force-pushed the gpu_pipe_docs branch 2 times, most recently from 9f69f60 to b33dd7d on April 22, 2025 at 21:24
@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Apr 25, 2025

@cperivol (Contributor) left a comment:

Nice work!

@apaszke self-assigned this on Apr 29, 2025

Member:

I think the picture is a bit misleading in that it makes it seem as if the WS schedule was actually worse? Perhaps compact the space in WG 1 so that while WG0 is doing copy_start, WG1 is doing copy_wait and matmul? Also, consider making the pipelined copy_waits short since the whole point is that we should no longer wait this much for memory if the transaction had enough time to complete?

Collaborator Author:

I compacted the figure so the two are mostly aligned in time.

So the memory thread spends most of its time on the consumed_barrier_wait, and the compute thread on matmul/copy_wait.

@@ -0,0 +1,325 @@
---

Member:

Just OOC, is there any benefit for us in keeping this guide as both ipynb and .md? The only thing I can think of is that people might be able to run it on Colab, but if it's mostly an explanation without much self-contained code then I'm not sure it's useful?

Collaborator Author:

I'm fine with just keeping the md, but it's also not really any extra work to have the ipynb, and there are a couple of runnable examples here for the matmuls.

name: python3
---

(pallas_mgpu_pipelining)=

Member:

OOC, is that some way to create links to docs?

Collaborator Author (@justinjfu, May 1, 2025):

Yes, you can refer back to this with {ref}`pallas_mgpu_pipelining`.

<!-- #region id="OkWmfqn7b53M" -->
We use the `carry_coroutine` pattern to initialize the WGMMA accumulator and to copy the final accumulator from registers into SMEM. Here, the carry coroutine is defined in the function `compute_thread`. It is critical that the accumulator be created inside the `compute_thread` function to avoid allocating it in the memory warpgroup, which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in `pl.run_state` in order to create an accumulator ref that is initialized to the carry value.
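
A heavily abbreviated sketch of that last step is shown below. The names `a_smem`, `b_smem`, and `acc_carry` are placeholders, and the use of `plgpu.ACC.init` to seed the accumulator ref from the carry is an assumption about the API rather than a quote from this guide:

```python
def pipeline_step(a_smem, b_smem, acc_carry):
  # Wrap the WGMMA in pl.run_state so that the carried accumulator value
  # becomes a mutable accumulator ref for the duration of the call.
  def do_wgmma(acc_ref):
    plgpu.wgmma(acc_ref, a_smem, b_smem)

  # run_state returns the final accumulator value, which becomes the
  # carry for the next pipeline step.
  return pl.run_state(do_wgmma)(plgpu.ACC.init(acc_carry))
```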

Instead of using `pl.pallas_call` to call the kernel, we use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of warpgroups to launch per CUDA block via the `num_threads` argument, as well as a `thread_name` we can use to query the warpgroup index inside the kernel.
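
A minimal sketch of that entry point follows. The `"wg"` thread name, the use of `lax.axis_index` to read the thread index, and the toy body are illustrative assumptions; in the real matmul kernel one warpgroup runs the compute pipeline while the other drives the memory copies:

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel_body(x_ref, o_ref):
  # thread_name="wg" below lets us query which Mosaic thread (warpgroup)
  # is running this body.
  wg = lax.axis_index("wg")

  @pl.when(wg == 0)
  def _():
    # Stand-in for the compute warpgroup's work.
    o_ref[...] = x_ref[...] + 1.0
  # The second warpgroup does nothing in this toy sketch; in the
  # warp-specialized matmul it would be the memory warpgroup.


x = jnp.zeros((128, 128), jnp.float32)
run = plgpu.kernel(
    kernel_body,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    num_threads=2,    # two warpgroups (Mosaic threads) per CUDA block
    thread_name="wg",
)
out = run(x)          # out == x + 1
```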

Member:

Please please please stop talking about warpgroups this much. It can be helpful to mention it from time to time, but the warpgroup really is not visible and is not a concept that makes sense in the semantics of Pallas:MGPU.

Collaborator Author:

I'll refer to this as "Mosaic thread" since there's potential for ambiguity between "CUDA thread" and "Mosaic thread", as Christos noted above.
