Skip to content

Commit 5845232

Browse files
authored
[Example] Update AG-GEMM example (#30)
* [Example] Use inplace A
* [Example] Update AG_GEMM
1 parent 7df1a99 commit 5845232

File tree

1 file changed

+96
-98
lines changed

1 file changed

+96
-98
lines changed

examples/distributed/example_allgather_gemm_overlapped.py

Lines changed: 96 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -22,99 +22,59 @@
2222

2323

2424
@tilelang.jit(pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True})
25-
def copy_and_barrier_all_intra_node_kernel(local_rank,
26-
rank,
27-
num_ranks,
28-
M,
29-
K,
30-
block_M,
31-
block_K,
32-
threads,
33-
dtype="float16"):
34-
35-
M_per_rank = T.ceildiv(M, num_ranks)
36-
sm_num = driver.get_num_sms()
37-
m_blocks = T.ceildiv(M_per_rank, block_M)
38-
k_blocks = T.ceildiv(K, block_K)
39-
waves = T.ceildiv(m_blocks * k_blocks, sm_num)
40-
41-
@T.macro
42-
def copy_kernel(src: T.Tensor((M_per_rank, K), dtype), dst: T.Tensor((M, K), dtype),
43-
data_shared: T.Tensor((block_M, block_K), dtype), block_id):
44-
for w in T.serial(waves):
45-
tile_id = sm_num * w + block_id
46-
bx = tile_id % m_blocks
47-
by = tile_id // m_blocks
48-
49-
if by < k_blocks:
50-
T.copy(src[bx * block_M, by * block_K], data_shared)
51-
T.copy(data_shared, dst[rank * M_per_rank + bx * block_M, by * block_K])
52-
53-
@T.macro
54-
def barrier_all_intra_node_non_atomic(
55-
sync_buffer: T.Tensor((3 * num_ranks), "uint32"), block_id):
56-
if block_id == 0:
57-
T.barrier_all_blocks_sys(sync_buffer)
58-
# barrier all CTAs
59-
T.sync_grid(sync_buffer[2 * num_ranks])
25+
def set_signal_kernel(local_rank, num_local_ranks, threads):
6026

6127
@T.prim_func
62-
def local_copy(
63-
A: T.Tensor((M_per_rank, K), dtype),
64-
ag_buffer: T.Tensor((M, K), dtype),
65-
signal_buffer: T.Tensor((num_ranks), "uint32"),
66-
sync_buffer: T.Tensor((3 * num_ranks), "uint32"),
67-
):
68-
with T.Kernel(sm_num, threads=threads) as (block_id):
69-
data_shared = T.alloc_shared((block_M, block_K), dtype)
70-
T.annotate_layout({data_shared: tilelang.layout.make_swizzled_layout(data_shared)})
71-
72-
barrier_all_intra_node_non_atomic(sync_buffer, block_id)
73-
copy_kernel(A, ag_buffer, data_shared, block_id)
28+
def _set_signal_kernel(signal_buffer: T.Tensor((num_local_ranks), "uint32"),):
29+
with T.Kernel(1, threads=threads):
7430
tx = T.get_thread_binding(0)
75-
if block_id == 0 and tx < num_ranks: # set symm barrier
76-
if tx == rank:
31+
if tx < num_local_ranks:
32+
if tx == local_rank:
7733
signal_buffer[tx] = 1
7834
else:
7935
signal_buffer[tx] = 0
80-
barrier_all_intra_node_non_atomic(sync_buffer, block_id)
8136

82-
return local_copy
37+
return _set_signal_kernel
8338

8439

8540
@tilelang.jit
8641
def gemm_kernel(M,
8742
N,
8843
K,
89-
num_rank,
9044
local_rank,
45+
num_local_rank,
9146
block_M,
9247
block_N,
9348
block_K,
9449
threads,
50+
persistent=False,
9551
dtype="float16",
9652
accum_dtype="float"):
9753

98-
M_per_rank = T.ceildiv(M, num_rank)
54+
sm_num = driver.get_num_sms()
55+
m_blocks = T.ceildiv(M, block_M)
56+
n_blocks = T.ceildiv(N // num_local_rank, block_N)
57+
waves = T.ceildiv(m_blocks * n_blocks, sm_num)
58+
M_per_rank = T.ceildiv(M, num_local_rank)
9959
GROUP_SIZE_M = 8
10060

10161
@T.prim_func
10262
def main(
10363
A: T.Tensor((M, K), dtype),
104-
B: T.Tensor((K, N // num_rank), dtype),
105-
signal_buffer: T.Tensor((num_rank), "uint32"),
106-
C: T.Tensor((M, N // num_rank), dtype),
64+
B: T.Tensor((K, N // num_local_rank), dtype),
65+
signal_buffer: T.Tensor((num_local_rank), "uint32"),
66+
C: T.Tensor((M, N // num_local_rank), dtype),
10767
):
10868
with T.Kernel(
109-
T.ceildiv(M, block_M) * T.ceildiv(N // num_rank, block_N),
69+
T.ceildiv(M, block_M) * T.ceildiv(N // num_local_rank, block_N),
11070
threads=threads) as (bid):
11171
A_shared = T.alloc_shared((block_M, block_K), dtype)
11272
B_shared = T.alloc_shared((block_K, block_N), dtype)
11373
C_shared = T.alloc_shared((block_M, block_N), dtype)
11474
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
11575

11676
num_pid_m = T.ceildiv(M, block_M)
117-
num_pid_n = T.ceildiv(N // num_rank, block_N)
77+
num_pid_n = T.ceildiv(N // num_local_rank, block_N)
11878
num_pid_in_group = GROUP_SIZE_M * num_pid_n
11979
group_id = bid // num_pid_in_group
12080
first_pid_m = group_id * GROUP_SIZE_M
@@ -140,55 +100,94 @@ def main(
140100
T.copy(C_local, C_shared)
141101
T.copy(C_shared, C[pid_m * block_M, pid_n * block_N])
142102

143-
return main
103+
@T.prim_func
104+
def main_persistent(
105+
A: T.Tensor((M, K), dtype),
106+
B: T.Tensor((K, N // num_local_rank), dtype),
107+
signal_buffer: T.Tensor((num_local_rank), "uint32"),
108+
C: T.Tensor((M, N // num_local_rank), dtype),
109+
):
110+
with T.Kernel(sm_num, threads=threads) as (bid):
111+
A_shared = T.alloc_shared((block_M, block_K), dtype)
112+
B_shared = T.alloc_shared((block_K, block_N), dtype)
113+
C_shared = T.alloc_shared((block_M, block_N), dtype)
114+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
115+
116+
for w in T.serial(waves):
117+
tile_id = bid + w * sm_num
118+
num_pid_m = T.ceildiv(M, block_M)
119+
num_pid_n = T.ceildiv(N // num_local_rank, block_N)
120+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
121+
group_id = tile_id // num_pid_in_group
122+
first_pid_m = group_id * GROUP_SIZE_M
123+
group_size_m = T.min(num_pid_m - first_pid_m, GROUP_SIZE_M)
124+
pid_m_ = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
125+
pid_n_ = (tile_id % num_pid_in_group) // group_size_m
126+
127+
# threadblock swizzle
128+
# no stream-k support. only split by m x n
129+
m_offset = M_per_rank * local_rank
130+
pid_m_offset = T.ceildiv(m_offset, block_M)
131+
pid_m = (pid_m_ + pid_m_offset) % num_pid_m
132+
pid_n = pid_n_
133+
134+
if pid_n_ * block_N < (N // num_local_rank) and pid_m_ * block_M < M:
135+
tid = T.get_thread_binding(0)
136+
T.clear(C_local)
137+
if tid == 0:
138+
T.wait_eq(signal_buffer[pid_m * block_M // M_per_rank], 1)
139+
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
140+
T.copy(A[pid_m * block_M, k * block_K], A_shared)
141+
T.copy(B[k * block_K, pid_n * block_N], B_shared)
142+
T.gemm(A_shared, B_shared, C_local)
143+
T.copy(C_local, C_shared)
144+
T.copy(C_shared, C[pid_m * block_M, pid_n * block_N])
145+
146+
return main if not persistent else main_persistent
144147

145148

146149
def cp_engine_producer_all_gather_full_mesh_pull(
147-
local_tensor,
148150
ag_buffer,
149151
signal_buffer,
150152
M_per_rank,
151-
N,
152153
signal_target,
153-
rank,
154+
local_rank,
154155
local_world_size,
155-
world_size,
156156
intranode_ag_stream,
157157
):
158-
rank_orders = [(rank + i) % local_world_size for i in range(local_world_size)]
158+
rank_orders = [(local_rank + i) % local_world_size for i in range(local_world_size)]
159159

160160
with torch.cuda.stream(intranode_ag_stream):
161161
for src_rank in rank_orders:
162-
if src_rank == rank:
162+
if src_rank == local_rank:
163163
continue
164-
dst = ag_buffer[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
164+
dst = ag_buffer[local_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
165165
src = ag_buffer[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
166166
dst.copy_(src)
167167

168168
(err,) = cuda.cuStreamWriteValue32(
169169
intranode_ag_stream.cuda_stream,
170-
signal_buffer[rank][src_rank].data_ptr(),
170+
signal_buffer[local_rank][src_rank].data_ptr(),
171171
signal_target,
172172
cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
173173
)
174174

175175

176-
def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, sync_buffer, M_per_rank, N, signal_target, rank,
177-
group, local_world_size, world_size, local_copy_kernel, gemm_kernel, gemm_stream,
178-
ag_stream):
176+
def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, local_rank,
177+
local_world_size, set_signal_kernel, gemm_kernel, gemm_stream, ag_stream):
179178

180179
with torch.cuda.stream(gemm_stream):
181-
local_copy_kernel(
182-
A, ag_buffer[rank], signal_buffer[rank], sync_buffer, stream=gemm_stream.cuda_stream)
180+
set_signal_kernel(signal_buffer[local_rank], stream=gemm_stream.cuda_stream)
183181

184182
ag_stream.wait_stream(gemm_stream)
185183

186-
cp_engine_producer_all_gather_full_mesh_pull(A, ag_buffer, signal_buffer, M_per_rank, N,
187-
signal_target, rank, local_world_size, world_size,
184+
cp_engine_producer_all_gather_full_mesh_pull(ag_buffer, signal_buffer, M_per_rank,
185+
signal_target, local_rank, local_world_size,
188186
ag_stream)
189187

190188
with torch.cuda.stream(gemm_stream):
191-
gemm_kernel(ag_buffer[rank], B, signal_buffer[rank], C, stream=gemm_stream.cuda_stream)
189+
gemm_kernel(
190+
ag_buffer[local_rank], B, signal_buffer[local_rank], C, stream=gemm_stream.cuda_stream)
192191

193192
gemm_stream.wait_stream(ag_stream)
194193
current_stream = torch.cuda.current_stream()
@@ -212,6 +211,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
212211
M = args.M if args else 8192
213212
N = args.N if args else 8192
214213
K = args.K if args else 8192
214+
persistent = args.persistent
215215
M_per_rank = M // num_local_ranks
216216
N_per_rank = N // num_local_ranks
217217

@@ -221,48 +221,45 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
221221
threads = 256
222222

223223
rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
224+
assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now"
224225
allocator = tilelang.get_allocator(
225226
size=2**30,
226227
device="cuda",
227228
is_distributed=True,
228229
local_rank=local_rank,
229230
num_local_ranks=num_local_ranks,
230231
group=group)
231-
kernel = gemm_kernel(M, N, K, num_ranks, rank, BLOCK_M, BLOCK_N, BLOCK_K, threads)
232-
local_copy_kernel = copy_and_barrier_all_intra_node_kernel(
232+
gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K,
233+
threads, persistent)
234+
set_signal_func = set_signal_kernel(
233235
local_rank=local_rank,
234-
rank=local_rank,
235-
num_ranks=num_ranks,
236-
M=M,
237-
K=K,
238-
block_M=64,
239-
block_K=64,
240-
threads=128,
236+
num_local_ranks=num_local_ranks,
237+
threads=32,
241238
)
242-
kernel.initialize(allocator=allocator)
243-
local_copy_kernel.initialize(allocator=allocator)
239+
gemm_func.initialize(allocator=allocator)
240+
set_signal_func.initialize(allocator=allocator)
244241
if local_rank == 1:
245-
print(kernel.get_kernel_source())
246-
print(local_copy_kernel.get_kernel_source())
242+
print(gemm_func.get_kernel_source())
243+
print(set_signal_func.get_kernel_source())
247244

248-
A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
249245
B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
250246
C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
251247
ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
248+
A = ag_buffer[local_rank][M_per_rank * local_rank:M_per_rank * (local_rank + 1), :].normal_()
252249
signal_buffer = tilelang.tensor((num_local_ranks,),
253250
torch.uint32,
254251
allocator=allocator,
255252
return_peers=True)
256-
signal_buffer[rank].fill_(0) # check if needed
257-
sync_buffer = tilelang.tensor((3 * num_ranks,), torch.uint32, allocator=allocator)
258253

259254
gemm_stream = torch.cuda.Stream()
260255
ag_stream = torch.cuda.Stream(priority=-1)
261256
signal_target = 1
262257

263-
tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, sync_buffer, M_per_rank, K,
264-
signal_target, rank, group, num_local_ranks, num_local_ranks,
265-
local_copy_kernel, kernel, gemm_stream, ag_stream)
258+
dist.barrier()
259+
260+
tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
261+
local_rank, num_local_ranks, set_signal_func, gemm_func, gemm_stream,
262+
ag_stream)
266263

267264
torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda")
268265
torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer)
@@ -273,10 +270,10 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
273270
print(f"rank {local_rank} check failed.❌")
274271
print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}")
275272

276-
tl_out, tl_t = perf_fn(
277-
lambda: ag_gemm_op(A, B, C, ag_buffer, signal_buffer, sync_buffer, M_per_rank, K,
278-
signal_target, rank, group, num_local_ranks, num_local_ranks,
279-
local_copy_kernel, kernel, gemm_stream, ag_stream),
273+
_, tl_t = perf_fn(
274+
lambda:
275+
ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, local_rank,
276+
num_local_ranks, set_signal_func, gemm_func, gemm_stream, ag_stream),
280277
warmup=5,
281278
rep=10)
282279

@@ -294,6 +291,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
294291
parser.add_argument('--M', type=int, default=8192, help='M dimension')
295292
parser.add_argument('--N', type=int, default=28672, help='N dimension')
296293
parser.add_argument('--K', type=int, default=8192, help='K dimension')
294+
parser.add_argument('--persistent', action='store_true', help='Use persistent kernel')
297295
args = parser.parse_args()
298296
num_processes = args.num_processes
299297

0 commit comments

Comments
 (0)