Commit a499425

[Dev] Update cannon example, with non-specialize and specialize implementations

1 parent 2ee56e6 · commit a499425

1 file changed: examples/distributed/example_cannon.py (+229 −54 lines)
```diff
@@ -13,13 +13,16 @@
 tilelang.disable_cache()
 
 
-def cannon(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"):
+def cannon(MESH, M, N, K, block_M, block_N, block_K, dtype="float16", specialize=False):
 
     M_local = T.ceildiv(M, MESH)
     N_local = T.ceildiv(N, MESH)
     K_local = T.ceildiv(K, MESH)
     accum_dtype = "float32"
 
+    sm_num = 132  # 132 SMs for H100
+    total_tiles = T.ceildiv(M_local, block_M) * T.ceildiv(N_local, block_N)
+
     @T.prim_func
     def main(
             A: T.Tensor((2, M_local, K_local), dtype),
```
```diff
@@ -30,8 +33,11 @@ def main(
             B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
             C: T.Tensor((M_local, N_local), dtype),
     ):
-        with T.Kernel(
-                T.ceildiv(M_local, block_M), T.ceildiv(N_local, block_N), threads=128) as (bx, by):
+        grid_size = T.min(sm_num, total_tiles)
+        A_rows_per_block = T.ceildiv(M_local, grid_size)
+        B_cols_per_block = T.ceildiv(N_local, grid_size)
+        waves = T.ceildiv(total_tiles, sm_num)
+        with T.Kernel(grid_size, threads=256) as (block_id):
            mype = T.alloc_local([1], "int32")
            npes = T.alloc_local([1], "int32")
            a_peer_from = T.alloc_local([1], "int32")
```
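Instead of one block per output tile, the kernel now launches a fixed pool of `grid_size` blocks that sweep the tile space in `waves` passes, so each block's copy work runs once per round rather than once per tile. Below is a minimal host-side sketch of that tile schedule; the sizes are illustrative only and not part of the diff:

```python
import math

sm_num = 132                              # H100 SM count, as in the kernel
M_local, N_local = 8192, 8192             # e.g. MESH = 2, M = N = 16384
block_M, block_N = 128, 256
tiles_m = math.ceil(M_local / block_M)    # 64
tiles_n = math.ceil(N_local / block_N)    # 32
total_tiles = tiles_m * tiles_n           # 2048

grid_size = min(sm_num, total_tiles)      # 132
waves = math.ceil(total_tiles / sm_num)   # 16

covered = set()
for w in range(waves):
    for block_id in range(grid_size):
        tile = grid_size * w + block_id
        bx, by = tile // tiles_n, tile % tiles_n
        if bx < tiles_m and by < tiles_n:  # same guard as the kernel
            covered.add((bx, by))
assert len(covered) == total_tiles         # every tile visited exactly once
```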
```diff
@@ -54,71 +60,180 @@ def main(
             for ko in T.serial(MESH):
                 if tx == 0:
                     T.signal_wait_until(
-                        T.address_of(A_signal_from[bx]),
-                        T.NVSHMEM_CMP_EQ,
-                        T.ceildiv(N_local, block_N) * ko,
+                        T.address_of(A_signal_from[0]),
+                        T.NVSHMEM_CMP_GE,
+                        total_tiles * ko,
                     )
                     T.signal_wait_until(
-                        T.address_of(B_signal_from[by]),
-                        T.NVSHMEM_CMP_EQ,
-                        T.ceildiv(M_local, block_M) * ko,
+                        T.address_of(B_signal_from[0]),
+                        T.NVSHMEM_CMP_GE,
+                        total_tiles * ko,
                     )
 
-                if by == 0:
+                if block_id < T.ceildiv(M_local, A_rows_per_block):
                     T.putmem_signal_nbi_block(
-                        T.address_of(A[(ko + 1) % 2, bx * block_M, 0]),
-                        T.address_of(A[ko % 2, bx * block_M,
-                                       0]), block_M * K_local * dsize_map[dtype],
-                        T.address_of(A_signal_to[bx]), ko + 1, T.NVSHMEM_SIGNAL_SET, a_peer_to[0])
-                if bx == 0:
+                        T.address_of(A[(ko + 1) % 2, A_rows_per_block * block_id, 0]),
+                        T.address_of(A[ko % 2, A_rows_per_block * block_id,
+                                       0]), A_rows_per_block * K_local * dsize_map[dtype],
+                        T.address_of(A_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, a_peer_to[0])
+                if block_id < T.ceildiv(N_local, B_cols_per_block):
                     T.putmem_signal_nbi_block(
-                        T.address_of(B[(ko + 1) % 2, by * block_N, 0]),
-                        T.address_of(B[ko % 2, by * block_N,
-                                       0]), block_N * K_local * dsize_map[dtype],
-                        T.address_of(B_signal_to[by]), ko + 1, T.NVSHMEM_SIGNAL_SET, b_peer_to[0])
-
-                for ki in T.Pipelined(T.ceildiv(K_local, block_K)):
-                    T.copy(
-                        A[ko % 2, bx * block_M:(bx + 1) * block_M, ki * block_K:(ki + 1) * block_K],
-                        A_shared)
-                    T.copy(
-                        B[ko % 2, by * block_N:(by + 1) * block_N, ki * block_K:(ki + 1) * block_K],
-                        B_shared)
-                    T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+                        T.address_of(B[(ko + 1) % 2, B_cols_per_block * block_id, 0]),
+                        T.address_of(B[ko % 2, B_cols_per_block * block_id,
+                                       0]), B_cols_per_block * K_local * dsize_map[dtype],
+                        T.address_of(B_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, b_peer_to[0])
+
+                for w in T.serial(waves):
+
+                    bx = (grid_size * w + block_id) // T.ceildiv(N_local, block_N)
+                    by = (grid_size * w + block_id) % T.ceildiv(N_local, block_N)
+
+                    if bx < T.ceildiv(M_local, block_M) and by < T.ceildiv(N_local, block_N):
+                        T.copy(C[bx * block_M, by * block_N], C_local)
+                        for ki in T.Pipelined(T.ceildiv(K_local, block_K), num_stages=4):
+                            T.copy(A[ko % 2, bx * block_M, ki * block_K], A_shared)
+                            T.copy(B[ko % 2, by * block_N, ki * block_K], B_shared)
+                            T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+
+                        T.copy(C_local, C[bx * block_M, by * block_N])
+                        if tx == 0:
+                            T.signal_op(
+                                T.address_of(A_signal_from[0]),
+                                1,
+                                T.NVSHMEM_SIGNAL_ADD,
+                                a_peer_from[0],
+                            )
+                            T.signal_op(
+                                T.address_of(B_signal_from[0]),
+                                1,
+                                T.NVSHMEM_SIGNAL_ADD,
+                                b_peer_from[0],
+                            )
+
+                # TODO: check if __syncthreads() is needed
+                T.signal_wait_until(
+                    T.address_of(A_signal_to[0]),
+                    T.NVSHMEM_CMP_GE,
+                    (ko + 1) * T.ceildiv(M_local, A_rows_per_block),
+                )
+                T.signal_wait_until(
+                    T.address_of(B_signal_to[0]),
+                    T.NVSHMEM_CMP_GE,
+                    (ko + 1) * T.ceildiv(N_local, B_cols_per_block),
+                )
```
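The handshake also changes from one `NVSHMEM_SIGNAL_SET` per tile checked with `NVSHMEM_CMP_EQ` to a single aggregated counter per tensor: producers add 1 with `NVSHMEM_SIGNAL_ADD`, and consumers wait with `NVSHMEM_CMP_GE` on the expected running total, which tolerates signals from faster blocks arriving early. A plain-Python sketch of the thresholds the waits above expect per Cannon round; `copy_chunks` is an illustrative stand-in for the per-round chunk count, e.g. `T.ceildiv(M_local, A_rows_per_block)`:

```python
def expected_thresholds(total_tiles, copy_chunks, MESH):
    # Running-counter values behind the CMP_GE waits in round ko.
    for ko in range(MESH):
        ready_to_compute = total_tiles * ko       # A/B_signal_from threshold
        buffer_may_flip = (ko + 1) * copy_chunks  # A/B_signal_to threshold
        print(f"round {ko}: compute waits for {ready_to_compute}, "
              f"buffer flip waits for {buffer_may_flip}")
```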
```diff
+
+    # TODO: fix correctness
+    @T.prim_func
+    def main_specialize(
+            A: T.Tensor((2, M_local, K_local), dtype),
+            B: T.Tensor((2, N_local, K_local), dtype),
+            A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"),
+            A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"),
+            B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
+            B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
+            C: T.Tensor((M_local, N_local), dtype),
+    ):
+        # 0-compute_blocks: compute
+        # compute_blocks-grid_size: copy
+        copy_blocks = 20
+        compute_blocks = T.min(sm_num - copy_blocks, total_tiles)
+        grid_size = copy_blocks + compute_blocks
+        A_rows_per_block = T.ceildiv(M_local, copy_blocks)
+        B_cols_per_block = T.ceildiv(N_local, copy_blocks)
+        waves = T.ceildiv(total_tiles, compute_blocks)
+        with T.Kernel(grid_size, threads=256) as (block_id):
+            mype = T.alloc_local([1], "int32")
+            npes = T.alloc_local([1], "int32")
+            a_peer_from = T.alloc_local([1], "int32")
+            a_peer_to = T.alloc_local([1], "int32")
+            b_peer_from = T.alloc_local([1], "int32")
+            b_peer_to = T.alloc_local([1], "int32")
+            mype[0] = T.get_pe()
+            npes[0] = T.get_pe_num()
+
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_N, block_K), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            tx = T.get_thread_binding(0)
+            a_peer_from[0] = (mype[0] + 1) % MESH + MESH * (mype[0] // MESH)
+            a_peer_to[0] = (mype[0] - 1 + MESH) % MESH + MESH * (mype[0] // MESH)
+            b_peer_from[0] = (mype[0] + MESH) % npes[0]
+            b_peer_to[0] = (mype[0] - MESH + npes[0]) % npes[0]
+            T.clear(C_local)
+            for ko in T.serial(MESH):
+                if block_id >= compute_blocks:
+                    if tx == 0:
+                        T.signal_wait_until(
+                            T.address_of(A_signal_from[0]),
+                            T.NVSHMEM_CMP_GE,
+                            total_tiles * ko,
+                        )
+                        T.signal_wait_until(
+                            T.address_of(B_signal_from[0]),
+                            T.NVSHMEM_CMP_GE,
+                            total_tiles * ko,
+                        )
+                    T.putmem_signal_nbi_block(
+                        T.address_of(A[(ko + 1) % 2, A_rows_per_block * (block_id - compute_blocks),
+                                       0]),
+                        T.address_of(A[ko % 2, A_rows_per_block * (block_id - compute_blocks),
+                                       0]), A_rows_per_block * K_local * dsize_map[dtype],
+                        T.address_of(A_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, a_peer_to[0])
+                    T.putmem_signal_nbi_block(
+                        T.address_of(B[(ko + 1) % 2, B_cols_per_block * (block_id - compute_blocks),
+                                       0]),
+                        T.address_of(B[ko % 2, B_cols_per_block * (block_id - compute_blocks),
+                                       0]), B_cols_per_block * K_local * dsize_map[dtype],
+                        T.address_of(B_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, b_peer_to[0])
+
+                if block_id < compute_blocks:
+                    for w in T.serial(waves):
+
+                        bx = (compute_blocks * w + block_id) // T.ceildiv(N_local, block_N)
+                        by = (compute_blocks * w + block_id) % T.ceildiv(N_local, block_N)
+
+                        if bx < T.ceildiv(M_local, block_M) and by < T.ceildiv(N_local, block_N):
+                            T.copy(C[bx * block_M, by * block_N], C_local)
+                            for ki in T.Pipelined(T.ceildiv(K_local, block_K), num_stages=4):
+                                T.copy(A[ko % 2, bx * block_M, ki * block_K], A_shared)
+                                T.copy(B[ko % 2, by * block_N, ki * block_K], B_shared)
+                                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+
+                            T.copy(C_local, C[bx * block_M, by * block_N])
+                            if tx == 0:
+                                T.signal_op(
+                                    T.address_of(A_signal_from[0]),
+                                    1,
+                                    T.NVSHMEM_SIGNAL_ADD,
+                                    a_peer_from[0],
+                                )
+                                T.signal_op(
+                                    T.address_of(B_signal_from[0]),
+                                    1,
+                                    T.NVSHMEM_SIGNAL_ADD,
+                                    b_peer_from[0],
+                                )
 
-                if tx == 0:
                 T.signal_wait_until(
-                    T.address_of(A_signal_to[bx]),
-                    T.NVSHMEM_CMP_EQ,
-                    ko + 1,
+                    T.address_of(A_signal_to[0]),
+                    T.NVSHMEM_CMP_GE,
+                    (ko + 1) * copy_blocks,
                 )
                 T.signal_wait_until(
-                    T.address_of(B_signal_to[by]),
-                    T.NVSHMEM_CMP_EQ,
-                    ko + 1,
-                )
-                T.signal_op(
-                    T.address_of(A_signal_from[bx]),
-                    1,
-                    T.NVSHMEM_SIGNAL_ADD,
-                    a_peer_from[0],
+                    T.address_of(B_signal_to[0]),
+                    T.NVSHMEM_CMP_GE,
+                    (ko + 1) * copy_blocks,
                 )
-                T.signal_op(
-                    T.address_of(B_signal_from[by]),
-                    1,
-                    T.NVSHMEM_SIGNAL_ADD,
-                    b_peer_from[0],
-                )
-                T.copy(C_local, C[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N])
 
-    return main
+    return main_specialize if specialize else main
```
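`main_specialize` splits the grid by role instead of interleaving copy and compute in every block: blocks `[0, compute_blocks)` run the GEMM waves, while the trailing `copy_blocks` blocks only forward the next round's A/B shards. A tiny sketch of that assignment (hypothetical helper, not part of the diff):

```python
def block_role(block_id, compute_blocks):
    # Blocks below compute_blocks compute; the rest copy.
    return "compute" if block_id < compute_blocks else "copy"

sm_num, copy_blocks = 132, 20
compute_blocks = sm_num - copy_blocks  # 112 on H100, assuming enough tiles
roles = [block_role(b, compute_blocks) for b in range(sm_num)]
assert roles.count("copy") == copy_blocks
```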
```diff
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--M", default=256, type=int)
-    parser.add_argument("--N", default=256, type=int)
-    parser.add_argument("--K", default=256, type=int)
+    parser.add_argument("--M", default=16384, type=int)
+    parser.add_argument("--N", default=16384, type=int)
+    parser.add_argument("--K", default=16384, type=int)
     parser.add_argument("--warmup", default=20, type=int, help="warmup iterations")
     parser.add_argument("--iters", default=100, type=int, help="perf iterations")
     parser.add_argument("--dtype", default="float16", type=str, help="data type")
```
```diff
@@ -135,14 +250,15 @@ def parse_args():
 assert MESH * MESH == WORLD_SIZE, "Mesh size must match world size"
 
 M, N, K = args.M, args.N, args.K
-block_M, block_N, block_K = 64, 64, 64
+specialize = False
+block_M, block_N, block_K = 128, 256, 64
 dtype = dtype_map[args.dtype]
 
 M_local = math.ceil(M / MESH)
 N_local = math.ceil(N / MESH)
 K_local = math.ceil(K / MESH)
 
-func = cannon(MESH, M, N, K, block_M, block_N, block_K, args.dtype)
+func = cannon(MESH, M, N, K, block_M, block_N, block_K, args.dtype, specialize)
 kernel = tilelang.compile(
     func, pass_configs={
         "tl.disable_tma_lower": True,
```
```diff
@@ -210,8 +326,67 @@ def parse_args():
     print('-' * 100)
     print(f"[Rank {RANK}] ✅ Tilelang and Torch match")
 else:
+    abs_error = torch.abs(C_tilelang - ref)
+    rel_error = abs_error / (torch.abs(ref) + 1e-8)
+
+    max_abs_error = abs_error.max().item()
+    max_rel_error = rel_error.max().item()
+    mismatch_ratio = (abs_error > (1e-2 + 1e-2 * torch.abs(ref))).float().mean().item()
+
     print('-' * 100)
     print(f"[Rank {RANK}] ❌ Tilelang and Torch mismatch")
     print(f"[Rank {RANK}] ref:\n{ref}")
     print(f"[Rank {RANK}] tilelang:\n{C_tilelang}")
+    print(f"[Rank {RANK}] Mismatch ratio: {mismatch_ratio:.4f}")
+    print(f"[Rank {RANK}] Max absolute error: {max_abs_error:.6f}")
+    print(f"[Rank {RANK}] Max relative error: {max_rel_error:.6f}")
 dist.barrier()
```
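The mismatch predicate has the same form as `torch.isclose` with `rtol = atol = 1e-2` (flag an element when `|a - b| > atol + rtol * |ref|`), so an equivalent formulation, reusing `C_tilelang` and `ref` from the script, would be:

```python
# Same predicate as the mismatch_ratio above, expressed via torch.isclose.
mismatch = ~torch.isclose(C_tilelang, ref, rtol=1e-2, atol=1e-2)
mismatch_ratio = mismatch.float().mean().item()
```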
```diff
+
+
+def bench(func, *args):
+    bench_iters = 10
+    torch.cuda._sleep(1000000000)
+
+    def preprocess():
+        # clear signals
+        args[2].fill_(0)
+        args[3].fill_(0)
+        args[4].fill_(0)
+        args[5].fill_(0)
+
+    # warmup
+    for _ in range(20):
+        preprocess()
+        _ = func(*args)
+
+    st = torch.cuda.Event(enable_timing=True)
+    ed = torch.cuda.Event(enable_timing=True)
+    # bench
+    st.record()
+    for _ in range(bench_iters):
+        preprocess()
+        _ = func(*args)
+    ed.record()
+    torch.cuda.synchronize()
+    avg_time = st.elapsed_time(ed) / bench_iters
+
+    return avg_time
```
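`bench` measures device time with CUDA events: `elapsed_time` returns milliseconds and is only valid once both events have completed, hence the `torch.cuda.synchronize()` before reading. The same pattern, reduced to its core:

```python
import torch

def time_gpu_ms(fn, iters=10):
    # Minimal CUDA-event timing loop (same approach as bench above).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()                 # wait until `end` is reached
    return start.elapsed_time(end) / iters   # milliseconds per iteration
```

The `torch.cuda._sleep(...)` call is a private PyTorch helper that enqueues a long spin kernel, so the subsequent launches queue up on the device rather than being gated by host-side launch latency.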
```diff
+
+
+def reduce_local_time(local_time):
+    tensor = torch.tensor([local_time], dtype=torch.float32).to("cuda")
+    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
+    if dist.get_rank() == 0:
+        world_size = dist.get_world_size()
+        mean_time = (tensor / world_size).item()
+        return mean_time
+    return None
```
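`reduce_local_time` sums the per-rank averages onto rank 0 and divides by the world size, so only rank 0 sees the mean and every other rank returns `None`. If all ranks needed the value, an `all_reduce` variant would do the same job (a sketch, assuming the process group is initialized as in the script):

```python
def reduce_local_time_all(local_time):
    # Variant where every rank receives the mean time.
    t = torch.tensor([local_time], dtype=torch.float32, device="cuda")
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return (t / dist.get_world_size()).item()
```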
```diff
+
+
+total_flops = 2 * M * N * K
+avg_time = reduce_local_time(
+    bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang))
+
+if RANK == 0:
+    print(f"avg time of RANK {RANK}: {avg_time} ms")
+    print(f"TFlops: {total_flops / avg_time * 1e-9} TFlops")
```
