Commit edac1b1

[Pallas][Mosaic GPU] Add GPU pipelining docs
1 parent: 1727657

9 files changed: +774 −6 lines

docs/_static/pallas/gpu/pipeline_matmul.svg

Lines changed: 1 addition & 0 deletions

docs/_static/pallas/gpu/pipeline_matmul_ws.svg

Lines changed: 1 addition & 0 deletions

docs/_static/pallas/gpu/warp_specialization.svg

Lines changed: 1 addition & 0 deletions

docs/conf.py

Lines changed: 2 additions & 0 deletions
@@ -134,6 +134,7 @@ def _do_not_evaluate_in_jax(
     'notebooks/*.md',
     'pallas/quickstart.md',
     'pallas/pipelining.md',
+    'pallas/gpu/pipelining.md',
     'pallas/tpu/pipelining.md',
     'pallas/tpu/distributed.md',
     'pallas/tpu/sparse.md',
@@ -230,6 +231,7 @@ def _do_not_evaluate_in_jax(
     # Requires accelerators
     'pallas/quickstart.*',
     'pallas/pipelining.*',
+    'pallas/gpu/pipelining.*',
     'pallas/tpu/pipelining.*',
     'pallas/tpu/distributed.*',
     'pallas/tpu/sparse.*',

docs/pallas/gpu/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ Backend specific documentation for the Mosaic GPU backend.
    :maxdepth: 2
 
    reference
+   pipelining
 
 .. toctree::
    :caption: Guides

docs/pallas/gpu/pipelining.ipynb

Lines changed: 428 additions & 0 deletions
Large diffs are not rendered by default.

docs/pallas/gpu/pipelining.md

Lines changed: 332 additions & 0 deletions
Large diffs are not rendered by default.
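
The two new documents above carry the substance of this commit but are not rendered in this view. As a rough, backend-agnostic sketch of the communication–compute overlap (multiple buffering) that GPU pipelining docs of this kind describe, the toy Python model below may help orient readers; `pipelined_loop`, `start_copy`, `wait_for`, and `compute` are hypothetical stand-ins for illustration, not APIs from the new docs.

```python
# Schematic model of a double-buffered pipeline: while block i is being
# computed, the copy for block i + num_buffers is already in flight.
def pipelined_loop(blocks, start_copy, wait_for, compute, num_buffers=2):
    in_flight = [None] * num_buffers
    results = []
    # Prologue: kick off the first copies before doing any compute.
    for slot in range(min(num_buffers, len(blocks))):
        in_flight[slot] = start_copy(blocks[slot])
    # Steady state: wait only for the buffer we need, refill it, then compute.
    for i in range(len(blocks)):
        slot = i % num_buffers
        data = wait_for(in_flight[slot])
        if i + num_buffers < len(blocks):
            in_flight[slot] = start_copy(blocks[i + num_buffers])
        results.append(compute(data))
    return results

# Toy stand-ins: "copies" complete immediately and "compute" doubles the block.
demo = pipelined_loop(
    blocks=[1, 2, 3, 4],
    start_copy=lambda b: b,          # would issue an async HBM->SMEM copy
    wait_for=lambda handle: handle,  # would block on the copy's completion
    compute=lambda data: data * 2,
)
assert demo == [2, 4, 6, 8]
```

This bookkeeping (prologue, steady state, buffer rotation) is what the Pallas pipelining APIs described in these docs are meant to handle for the kernel author.
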

docs/pallas/pipelining.ipynb

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@
     "\n",
     "Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.\n",
     "\n",
-    "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or GPU (coming soon!) specific pipelining references.\n"
+    "This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`.\n"
    ]
   },
   {
@@ -853,6 +853,7 @@
    "provenance": []
   },
   "jupytext": {
+   "formats": "ipynb,md",
    "main_language": "python"
   },
   "kernelspec": {

docs/pallas/pipelining.md

Lines changed: 6 additions & 5 deletions
@@ -1,6 +1,7 @@
 ---
 jupyter:
   jupytext:
+    formats: ipynb,md
     main_language: python
     text_representation:
       extension: .md
@@ -20,7 +21,7 @@ jupyter:
 
 Software pipelining is an important technique in performance optimization by overlapping multiple asynchronous operations even if there are data dependencies between them. In the context of kernel writing, the most common form of pipelining involves overlapping communication and memory transfers with compute such that the hardware accelerator never stalls while waiting for data to arrive. Therefore, we will solely focus on the problem of communication-compute pipelining in this tutorial. We will begin by covering the problem conceptually, outlining the Pallas API for writing pipelines, and going over some realistic examples using the API.
 
-This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or GPU (coming soon!) specific pipelining references.
+This tutorial only covers the conceptual foundations of pipelining. For platform-specific references, please see {ref}`pallas_tpu_pipelining`, or {ref}`pallas_mgpu_pipelining`.
 
 <!-- #endregion -->
 
@@ -63,7 +64,7 @@ In order to perform computation on values X and Y that live in HBM, we need to:
 Let’s implement a Pallas function that does just that!
 <!-- #endregion -->
 
-```python id="IrPhDFnT3Nvw" executionInfo={"status": "ok", "timestamp": 1744764235906, "user_tz": 420, "elapsed": 108, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4"
+```python executionInfo={"elapsed": 108, "status": "ok", "timestamp": 1744764235906, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="IrPhDFnT3Nvw" outputId="8bc03872-fd9f-4610-9d53-d4b46be560f4"
 # Note: This is a TPU example.
 
 def add_matrices_kernel(x_sram_ref, y_sram_ref, z_sram_ref):
@@ -480,7 +481,7 @@ As a concrete example, let's consider performing the following computation for r
 
 <!-- #endregion -->
 
-```python id="4qz1ET-_f9fJ" executionInfo={"status": "ok", "timestamp": 1744763773938, "user_tz": 420, "elapsed": 244, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="e43067ef-933a-45a5-912a-e224151cfa60"
+```python executionInfo={"elapsed": 244, "status": "ok", "timestamp": 1744763773938, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="4qz1ET-_f9fJ" outputId="e43067ef-933a-45a5-912a-e224151cfa60"
 x = jnp.ones((8, 1024, 1024))
 jnp.sum(x, axis=0)
 ```
@@ -489,7 +490,7 @@ jnp.sum(x, axis=0)
 To do this using `pallas_call`, we could use a grid of size `(8,)` and in each iteration i load `x[i]` into SRAM. Then we could add `x[i]` to an output SRAM buffer. Let's implement this naively first.
 <!-- #endregion -->
 
-```python id="ZEi1_vQVf-81" executionInfo={"status": "ok", "timestamp": 1744763774254, "user_tz": 420, "elapsed": 79, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="581744b7-ddc1-4dc1-98ec-03c852772eda"
+```python executionInfo={"elapsed": 79, "status": "ok", "timestamp": 1744763774254, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="ZEi1_vQVf-81" outputId="581744b7-ddc1-4dc1-98ec-03c852772eda"
 # Note: This is a TPU example.
 
 # Warning: this implementation is incorrect!
@@ -521,7 +522,7 @@ There are two errors inside this kernel. First, we are accumulating along the fi
 After fixing these two issues, we obtain the following corrected kernel. In this new kernel, we use `@pl.when` to create a conditional that checks when the program ID is `0` along the reduction axis, indicating we are beginning to accumulate into a new output block. We have also moved the reduction dimension to the last axis of the `grid`.
 <!-- #endregion -->
 
-```python id="XtgD4nMa9_Bd" executionInfo={"status": "ok", "timestamp": 1744763774523, "user_tz": 420, "elapsed": 104, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}} outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801"
+```python executionInfo={"elapsed": 104, "status": "ok", "timestamp": 1744763774523, "user": {"displayName": "Justin Fu", "userId": "17543197034567316452"}, "user_tz": 420} id="XtgD4nMa9_Bd" outputId="9ef07cdf-9e22-4dc8-c17f-c96172639801"
 # Note: This is a TPU example.
 
 def correct_sum_kernel(x_ref, o_ref):
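
The hunk above is cut off at the kernel definition. For context, a minimal sketch of the `@pl.when` accumulation pattern the surrounding prose describes might look like the following; it is an illustrative reconstruction, not the exact code in `docs/pallas/pipelining.md`, and the wrapper `block_sum` with its block specs is an assumption (simplified to a 1-D grid over the reduction axis).

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def correct_sum_kernel(x_ref, o_ref):
  # On the first grid step along the reduction axis, reset the accumulator.
  @pl.when(pl.program_id(axis=0) == 0)
  def _():
    o_ref[...] = jnp.zeros_like(o_ref)
  # Accumulate the current input block into the output block.
  o_ref[...] += x_ref[...]

def block_sum(x):  # hypothetical wrapper name
  n, h, w = x.shape
  return pl.pallas_call(
      correct_sum_kernel,
      grid=(n,),
      # `None` squeezes the leading axis away, so the kernel sees (h, w) blocks.
      in_specs=[pl.BlockSpec((None, h, w), lambda i: (i, 0, 0))],
      out_specs=pl.BlockSpec((h, w), lambda i: (0, 0)),
      out_shape=jax.ShapeDtypeStruct((h, w), x.dtype),
  )(x)

x = jnp.ones((8, 1024, 1024))
# On TPU, block_sum(x) should match jnp.sum(x, axis=0).
```

Because the reduction axis is the only grid axis here, every grid step revisits the same output block, so zeroing it on step 0 and accumulating on later steps reproduces `jnp.sum(x, axis=0)`.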
