[Pallas] [Mosaic GPU] Add GPU pipelining docs #28135
Conversation
Force-pushed a2708ed to 6485c7f
docs/pallas/gpu/pipelining.md (Outdated)
- `body` and `grid` have the same semantics as in `pl.pallas_call`. The `grid` denotes how many invocations of the `body` function to run. In contrast with a CUDA grid, the pipeline grid is guaranteed to run sequentially.
- `in_specs` and `out_specs` also work the same as in `pl.pallas_call`, except they also accept `plgpu.GPUBlockSpec` instances that can be used to specify GPU-specific transforms, such as swizzling. See [memory reference transforms](https://docs.jax.dev/en/latest/pallas/gpu/reference.html#memory-reference-transforms) for more detail on the available transformations.
- `max_concurrent_steps` controls the maximum number of pipeline stages to use. Using additional stages will consume more SMEM to hold temporary buffers, so this option should be used carefully.
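(Aside for readers of this thread: a minimal sketch of how these `emit_pipeline` arguments fit together. The copy body, shapes, and grid below are made up for illustration, and the exact signature may differ between JAX versions.)

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def copy_kernel(x_gmem, o_gmem):
  # The pipeline body only sees the SMEM blocks for the current grid step.
  def copy_block(x_smem, o_smem):
    o_smem[...] = x_smem[...]

  block_spec = pl.BlockSpec((128, 128), lambda i: (0, i))
  pipeline = plgpu.emit_pipeline(
      copy_block,
      grid=(8,),               # runs sequentially, unlike a CUDA grid
      in_specs=[block_spec],
      out_specs=[block_spec],
      max_concurrent_steps=2,  # copies kept in flight; each extra step costs SMEM
  )
  pipeline(x_gmem, o_gmem)     # invoked on the GMEM refs inside the kernel

x = jnp.ones((128, 1024), jnp.float32)
out = plgpu.kernel(copy_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```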
Not sure if "pipeline stages" is quite right here. Should we explain this in terms of the number of copy operations running in concurrently?
CC @apaszke
I updated this to just say the number of concurrent memory copies in flight.
docs/pallas/gpu/pipelining.md (Outdated)
- `delay_release` allows the user to specify the number of steps to wait before re-using a buffer. This is useful for certain optimizations, such as allowing multiple async matmuls in flight to keep the tensor core pipeline filled.

As an alternative to `emit_pipeline`, Mosaic GPU also implements the existing `pl.pallas_call` API for pipelining. Pipelining with `pl.pallas_call` directly requires the user to pass in a `plgpu.GPUCompilerParams` object as the `compiler_params` argument, which specifies the following options that are relevant for pipelining:
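(Again a sketch for this thread, not the PR's own example: passing the pipelining options through `compiler_params` might look roughly like the snippet below. I'm assuming `max_concurrent_steps` and `delay_release` are accepted fields of `plgpu.GPUCompilerParams`, mirroring the `emit_pipeline` options above, and that supplying these params selects the Mosaic GPU backend; check the current API before relying on it.)

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def add_one(x_smem, o_smem):
  o_smem[...] = x_smem[...] + 1.0

x = jnp.zeros((8, 128, 128), jnp.float32)
block = pl.BlockSpec((None, 128, 128), lambda i: (i, 0, 0))
out = pl.pallas_call(
    add_one,
    grid=(8,),  # the grid steps are what get pipelined over
    in_specs=[block],
    out_specs=block,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    compiler_params=plgpu.GPUCompilerParams(
        max_concurrent_steps=2,  # assumed: same meaning as in emit_pipeline
        delay_release=1,         # assumed: extra steps before a buffer is reused
    ),
)(x)
```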
I would personally just leave this out tbh and tell people to use `plgpu.kernel`, since it allows you to do everything `pl.pallas_call` can and more (if you need to).
I'll mark this as a "Compatibility" API only and nudge people to use plgpu.kernel
instead? I think some users might prefer having "one way to do things" and use pallas_call for both GPU/TPU.
Force-pushed 9f69f60 to b33dd7d
Nice work!
Force-pushed b33dd7d to ed3a842
I think the picture is a bit misleading in that it makes it seem as if the WS schedule were actually worse. Perhaps compact the space in WG1 so that while WG0 is doing copy_start, WG1 is doing copy_wait and matmul? Also, consider making the pipelined copy_waits short, since the whole point is that we should no longer wait this long for memory if the transaction had enough time to complete.
I compacted the figure so the two are mostly aligned in time. The memory thread now spends most of its time on the consumed_barrier_wait, and the compute thread on matmul/copy_wait.
@@ -0,0 +1,325 @@
---
Just OOC, is there any benefit for us in keeping this guide as both an ipynb as well as an md? The only thing I can think of is that people might be able to run it on Colab, but if it's mostly an explanation without much self-contained code then I'm not sure it's useful?
I'm fine with just keeping the md, but it's also not really any extra work to have the ipynb, and there are a couple of runnable examples here for the matmuls.
name: python3
---

(pallas_mgpu_pipelining)=
OOC, is that some way to create links to docs?
Yes, you can refer back to this with {ref}`pallas_mgpu_pipelining`.
docs/pallas/gpu/pipelining.md (Outdated)
We use the `carry_coroutine` pattern to initialize the WGMMA accumulator and copy the final accumulator from registers into SMEM. Here, the carry coroutine is defined in the function `compute_thread`. It is critical that the accumulator be created inside of the `compute_thread` function to avoid allocating it in the memory warpgroup, which would waste registers. To perform the WGMMA, we wrap the `wgmma` instruction in `pl.run_state` in order to create an accumulator ref that is initialized to the carry value.

Instead of using `pl.pallas_call` to call the kernel, we use the GPU-specific `plgpu.kernel` entry point. `plgpu.kernel` allows us to specify the number of warpgroups to launch per CUDA block via the `num_threads` argument, and lets us specify a `thread_name` we can use to query the warpgroup index inside of the kernel.
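(A structural sketch of the `plgpu.kernel` entry point being described here, with placeholder bodies; the shapes and the use of `jax.lax.axis_index` to read the thread index are my assumptions, not code from the PR.)

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def kernel(x_gmem, o_gmem):
  # Both warpgroups execute this body; the index distinguishes their roles.
  wg = jax.lax.axis_index("wg")

  @pl.when(wg == 0)
  def _memory_thread():
    pass  # e.g. issue async GMEM->SMEM copies here

  @pl.when(wg == 1)
  def _compute_thread():
    pass  # e.g. run the WGMMA / pipeline body here

x = jnp.zeros((128, 128), jnp.float32)
out = plgpu.kernel(
    kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    num_threads=2,     # warpgroups launched per CUDA block
    thread_name="wg",  # the name queried via jax.lax.axis_index above
)(x)
```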
Please please please stop talking about warpgroups this much. It can be helpful to mention it from time to time, but we really shouldn't lean on it this heavily. The warpgroup really is not visible and is not a concept that makes sense in the semantics of Pallas:MGPU.
I'll refer to this as "Mosaic thread", since there's potential for ambiguity between "CUDA thread" and "Mosaic thread", as Christos noted above.
Adds docs covering `emit_pipeline` and `emit_pipeline_warp_specialized`.
Remaining TODOs possibly include adding a flash attention example with ping-pong scheduling.