Commit d8ccff5

Add hl.wait & AllGather Matmul example (ptx impl).
stack-info: PR: #189, branch: joydddd/stack/5
1 parent feb86dc commit d8ccff5
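
The example overlaps a copy-engine all-gather with a Helion matmul on the compute stream: the gather loop publishes each copied chunk by writing 1 into a per-chunk progress flag (symm_mem_hdl.stream_write_value32, which issues a system-level fence before the write), and the matmul kernel blocks on that flag with the new hl.wait op before reading the rows it guards. A minimal sketch of the two sides of the handshake, using the argument names from the diff below; chunk_idx is a stand-in for the flag index computed in the kernel:

    # Producer (backend copy stream): mark chunk `chunk_idx` as ready.
    symm_mem_hdl.stream_write_value32(progress, offset=chunk_idx, val=1)

    # Consumer (inside the Helion kernel): spin until progress[chunk_idx] == 1,
    # loading the flag with acquire semantics at GPU scope.
    hl.wait(
        progress,
        [chunk_idx],
        signal=1,      # value that marks the chunk as ready
        update=None,   # load-only: the flag is not modified
        op="ld",
        scope="gpu",
        sem="acquire",
    )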

4 files changed, +530 -0 lines changed

examples/all_gather_matmul.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
from __future__ import annotations

import os
from typing import Any

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


def copy_engine_all_gather_w_progress(
    output: torch.Tensor,
    inp: torch.Tensor,  # Must be symmetric tensor
    progress: torch.Tensor,
    splits_per_rank: int,
    backend_stream: torch.cuda.Stream | None = None,
) -> torch.cuda.Stream:
    """All-gather `inp` into `output` on a backend stream, setting one entry of
    `progress` to 1 as each chunk becomes valid."""
    if backend_stream is None:
        backend_stream = symm_mem._get_backend_stream(priority=-1)
    assert inp.is_contiguous()
    symm_mem_group = dist.group.WORLD
    if symm_mem_group is None:
        raise RuntimeError("No symmetric memory group available")
    symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group)
    assert symm_mem_hdl is not None

    rank = symm_mem_hdl.rank
    world_size = symm_mem_hdl.world_size

    assert inp.numel() % splits_per_rank == 0
    assert progress.numel() >= world_size * splits_per_rank

    output_shape = list(inp.shape)
    output_shape[0] *= world_size
    assert list(output.shape) == output_shape, (list(output.shape), output_shape)

    chunks = output.chunk(world_size * splits_per_rank)

    symm_mem_hdl.barrier()
    backend_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(backend_stream):
        for step in range(world_size):
            src_rank = (rank + step + 1) % world_size
            for split_id in range(splits_per_rank):
                src_buf = symm_mem_hdl.get_buffer(
                    src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
                )
                chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
                # cuStreamWriteValue32 issues a system-level fence before the write
                symm_mem_hdl.stream_write_value32(
                    progress,
                    offset=src_rank * splits_per_rank + split_id,
                    val=1,
                )
        symm_mem_hdl.barrier()

    return backend_stream


@helion.jit(
    config=helion.Config(
        block_sizes=[128, 256, 64],
        num_warps=8,
        num_stages=3,
        indexing="block_ptr",
    ),
    # Static shapes provide a speedup for this kernel
    static_shapes=True,
)
def helion_matmul_w_progress(
    a: torch.Tensor,
    a_shared: torch.Tensor,
    b: torch.Tensor,
    progress: torch.Tensor,
    SPLITS_PER_RANK: int,
    RANK: int,
) -> torch.Tensor:
    """Compute a @ b, waiting on `progress` so each row tile of `a` is read only
    after the copy engine has published the chunk that covers it."""
    M, K = a.size()
    K2, N = b.size()
    assert K2 == K, f"size mismatch {K2} != {K}"

    out = torch.empty(
        [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
    )

    M_per_rank = a_shared.size(0)

    for tile_m, tile_n in hl.tile([M, N]):
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        hl.wait(
            progress,
            [
                tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
            ],
            signal=1,
            update=None,
            op="ld",
            scope="gpu",
            sem="acquire",
        )
        for tile_k in hl.tile(K):
            # TODO(joydddd): use a_shared and skip barrier when data is available on local rank.
            # if tile_k.begin // M_per_rank == RANK:
            #     acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
            # else:
            #     hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
            acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
        out[tile_m, tile_n] = acc
    return out


def helion_all_gather_matmul(
    a_shared: torch.Tensor,
    b: torch.Tensor,
    a_out: torch.Tensor | None = None,
    progress: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-gather `a_shared` across ranks into `a_out` and overlap the copy with
    the matmul against `b`. Returns (a_out, a_out @ b)."""
    configs = {
        "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
        "BLOCK_SIZE_M": kwargs.get("block_size_m", 128),
        "BLOCK_SIZE_N": kwargs.get("block_size_n", 256),
        "BLOCK_SIZE_K": kwargs.get("block_size_k", 64),
        "GROUP_SIZE_M": kwargs.get("group_size_m", 4),
        "num_stages": kwargs.get("num_stages", 3),
        "num_warps": kwargs.get("num_warps", 8),
    }

    symm_mem_group = dist.group.WORLD
    if symm_mem_group is None:
        raise RuntimeError("No symmetric memory group available")

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group)

    a_shape = list(a_shared.shape)
    a_shape[0] *= symm_mem_hdl.world_size

    configs["RANK"] = symm_mem_hdl.rank
    configs["WORLD_SIZE"] = symm_mem_hdl.world_size
    if (
        configs["SPLITS_PER_RANK"]
        * configs["WORLD_SIZE"]
        * configs["BLOCK_SIZE_M"]
        * configs["GROUP_SIZE_M"]
        > a_shape[0]
    ):
        configs["GROUP_SIZE_M"] = 1
        configs["SPLITS_PER_RANK"] = 1

    configs["COMM_BLOCK_SIZE_M"] = (
        a_shape[0] // configs["WORLD_SIZE"] // configs["SPLITS_PER_RANK"]
    )
    assert (
        configs["COMM_BLOCK_SIZE_M"]
        % (configs["BLOCK_SIZE_M"] * configs["GROUP_SIZE_M"])
        == 0
    )

    if a_out is None:
        a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)

    if progress is None:
        progress = torch.zeros(
            symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
            dtype=torch.uint32,
            device=a_shared.device,
        )
    else:
        # Reset progress to 0. Maybe we should reset inside the kernel using cas?
        progress.fill_(0)

    backend_stream = copy_engine_all_gather_w_progress(
        a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
    )

    c = helion_matmul_w_progress(
        a_out,
        a_shared,
        b,
        progress,
        SPLITS_PER_RANK=configs["SPLITS_PER_RANK"],
        RANK=configs["RANK"],
    )
    assert type(c) is torch.Tensor

    torch.cuda.current_stream().wait_stream(backend_stream)

    return a_out, c


def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
    a_shared = symm_mem.empty(
        M // world_size, K, dtype=torch.bfloat16, device=device
    ).normal_()
    # b is (K, N) stored column-major (transpose of a contiguous (N, K) tensor).
    b = torch.randn((K, N), device=device, dtype=torch.bfloat16).T.contiguous().T

    a_out, c = helion_all_gather_matmul(a_shared, b)

    golden_a = a_shared.clone()
    dist_group = dist.group.WORLD
    if dist_group is None:
        raise RuntimeError("No distributed group available")
    ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
        golden_a, [b], gather_dim=0, group_name=dist_group.group_name
    )
    torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(a_out, ag_golden)


def main() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(4096, 6656, 16384, world_size, device)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_gather_matmul.py
    """
    main()

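The flag index passed to hl.wait above is tile_m.begin // (M_per_rank // SPLITS_PER_RANK): the number of the gathered chunk that contains the tile's first row, matching the offset the copy engine writes in copy_engine_all_gather_w_progress. A small standalone sketch of that mapping, assuming the shapes used in test() (M = 4096, 8 ranks, SPLITS_PER_RANK = 1):

    M, world_size, splits_per_rank = 4096, 8, 1
    M_per_rank = M // world_size                   # rows contributed by each rank: 512
    rows_per_flag = M_per_rank // splits_per_rank  # rows guarded by one progress flag
    for tile_begin in (0, 511, 512, 4095):
        print(tile_begin, "->", tile_begin // rows_per_flag)
    # prints: 0 -> 0, 511 -> 0, 512 -> 1, 4095 -> 7
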
helion/language/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from .memory_ops import atomic_add as atomic_add
 from .memory_ops import load as load
 from .memory_ops import store as store
+from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end

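This re-export makes the new op available as hl.wait through the public helion.language namespace, which is how the example spells it. As a toy illustration only (not part of the commit; it passes no explicit helion.Config and assumes a producer elsewhere sets each flag to 1), a kernel can gate per-tile work on a progress flag the same way the matmul above does:

    import torch

    import helion
    import helion.language as hl


    @helion.jit(static_shapes=True)
    def scale_when_ready(
        x: torch.Tensor, progress: torch.Tensor, rows_per_flag: int
    ) -> torch.Tensor:
        m, n = x.size()
        out = torch.empty_like(x)
        for tile_m, tile_n in hl.tile([m, n]):
            # Block this tile until the flag covering its rows is set to 1.
            hl.wait(
                progress,
                [tile_m.begin // rows_per_flag],
                signal=1,
                update=None,
                op="ld",
                scope="gpu",
                sem="acquire",
            )
            out[tile_m, tile_n] = x[tile_m, tile_n] * 2
        return out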