
Commit 9de9e1b

[Feature][Example] Support return_peers in _allocate_tensor and add ag_gemm_ipc example
1 parent 2fbbd76 commit 9de9e1b

5 files changed: +249 −7 lines


examples/distributed/example_all_to_all.py

Lines changed: 0 additions & 2 deletions
@@ -1,14 +1,12 @@
 import torch
 import pynvshmem
-import os
 import tilelang
 import tilelang.language as T
 from tilelang.profiler import TensorSupplyType
 from tilelang.distributed.utils import init_distributed
 import argparse
 import random
 
-
 tilelang.disable_cache()
 
 
New file: the ag_gemm_ipc example

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
+import os
+import tilelang
+import tilelang.language as T
+import argparse
+import torch
+import torch.distributed as dist
+import torch.multiprocessing
+from tilelang.distributed.utils import init_dist
+from cuda import cudart
+from tilelang.distributed.utils import set_signal, wait_eq
+
+tilelang.disable_cache()
+os.environ['NCCL_DEBUG'] = 'WARN'  # silence NCCL log
+
+
+def gemm_kernel(M,
+                N,
+                K,
+                num_rank,
+                block_M,
+                block_N,
+                block_K,
+                threads,
+                dtype="float16",
+                accum_dtype="float"):
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N // num_rank), dtype),
+            C: T.Tensor((M, N // num_rank), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+                T.copy(A[by * block_M, k * block_K], A_shared)
+                T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N,
+                                      signal_target, rank, local_world_size, world_size,
+                                      intranode_ag_stream):
+    local_rank = rank % local_world_size
+    n_nodes = world_size // local_world_size
+    node_rank = rank // local_world_size
+
+    for i in range(1, local_world_size):
+        segment = rank * M_per_rank * N
+        local_dst_rank = (local_rank + local_world_size - i) % local_world_size
+        src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size()
+        dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size()
+        # Using copy engine to perform intranode transmission
+        # Sending rank-th local tensor to other ranks inside the node.
+        (err,) = cudart.cudaMemcpyAsync(
+            dst_ptr,
+            src_ptr,
+            M_per_rank * N * local_tensor.element_size(),
+            cudart.cudaMemcpyKind.cudaMemcpyDefault,
+            intranode_ag_stream.cuda_stream,
+        )
+        # Notify the peer that the transmission is done.
+        set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream)
+
+    for i in range(1, n_nodes):
+        recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size
+        recv_segment = recv_rank * M_per_rank * N
+        # Waiting for the internode data ready
+        wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream)
+        src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size()
+        for j in range(1, local_world_size):
+            local_dst_rank = (local_rank + local_world_size - j) % local_world_size
+            dst_ptr = ag_buffer[local_dst_rank].data_ptr(
+            ) + recv_segment * local_tensor.element_size()
+            # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node.
+            (err,) = cudart.cudaMemcpyAsync(
+                dst_ptr,
+                src_ptr,
+                M_per_rank * N * local_tensor.element_size(),
+                cudart.cudaMemcpyKind.cudaMemcpyDefault,
+                intranode_ag_stream.cuda_stream,
+            )
+            # Notify the peer that the transmission is done.
+            set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream)
+
+
+def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group,
+               local_world_size, world_size, gemm_kernel, ag_stream):
+
+    dist.barrier(group)
+
+    # all_gather A to ag_buffer
+    with torch.cuda.stream(ag_stream):
+        cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target,
+                                          rank, local_world_size, world_size, ag_stream)
+
+    current_stream = torch.cuda.current_stream()
+    current_stream.wait_stream(ag_stream)
+
+    dist.barrier(group)
+    torch.cuda.synchronize()
+
+    torch.cuda.synchronize()
+    torch.distributed.barrier(group)
+    gemm_kernel(ag_buffer[rank], B, C)
+    torch.cuda.synchronize()
+    torch.distributed.barrier(group)
+
+    return C
+
+
+def torch_ag_gemm(
+        pg: torch.distributed.ProcessGroup,
+        local_input: torch.Tensor,
+        local_weight: torch.Tensor,
+        ag_out: torch.Tensor,
+):
+    torch.distributed.all_gather_into_tensor(ag_out, local_input, pg)
+    ag_gemm_output = torch.matmul(ag_out, local_weight)
+    return ag_gemm_output
+
+
+def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
+    dtype = torch.float16
+    M = args.M if args else 8192
+    N = args.N if args else 8192
+    K = args.K if args else 8192
+    M_per_rank = M // num_local_ranks
+    N_per_rank = N // num_local_ranks
+
+    BLOCK_M = 128
+    BLOCK_N = 128
+    BLOCK_K = 64
+    threads = 256
+    assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other"
+
+    rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    allocator = tilelang.get_allocator(
+        size=2**30,
+        device="cuda",
+        is_distributed=True,
+        local_rank=local_rank,
+        num_local_ranks=num_local_ranks,
+        group=group)
+    kernel = tilelang.compile(gemm_kernel(M, N, K, num_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads))
+    kernel.initialize(allocator=allocator)
+    if local_rank == 0:
+        print(kernel.get_kernel_source())
+
+    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
+    B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
+    C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
+    ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
+    signal_buffer = tilelang.tensor((num_local_ranks,),
+                                    torch.int32,
+                                    allocator=allocator,
+                                    return_peers=True)
+    signal_buffer[rank].fill_(0)
+    ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A)
+
+    dist.barrier(group)
+
+    ag_stream = torch.cuda.Stream()
+    signal_target = 1
+
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank,
+                            group, num_local_ranks, num_local_ranks, kernel, ag_stream)
+
+    torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda")
+    torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer)
+
+    if torch.allclose(torch_C, tilelang_C, atol=1e-6, rtol=1e-6):
+        print(f"rank {local_rank} check passed.✅")
+    else:
+        print(f"rank {local_rank} check failed.❌")
+        print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}")
+        raise ValueError("Test failed")
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)')
+    parser.add_argument('--M', type=int, default=8192, help='M dimension')
+    parser.add_argument('--N', type=int, default=8192, help='N dimension')
+    parser.add_argument('--K', type=int, default=8192, help='K dimension')
+    args = parser.parse_args()
+    num_processes = args.num_processes
+
+    torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes)
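
A note on the signaling convention the example relies on: because signal_buffer is allocated with return_peers=True, each rank can address every peer's flag vector, and signal_buffer[dst][src] is the int32 flag that rank src raises on rank dst once src's row block has landed in dst's ag_buffer. A minimal sketch of that handshake, reusing the example's helpers (dst, src, and nbytes are placeholder names, not identifiers from the file):

# Producer side, running on rank `src`: copy the block into dst's buffer, then raise dst's flag.
(err,) = cudart.cudaMemcpyAsync(dst_ptr, src_ptr, nbytes,
                                cudart.cudaMemcpyKind.cudaMemcpyDefault,
                                ag_stream.cuda_stream)
set_signal(signal_buffer[dst][src], signal_target, ag_stream)

# Consumer side, running on rank `dst`: stall the compute stream until src's flag arrives;
# rows [src * M_per_rank, (src + 1) * M_per_rank) of ag_buffer[dst] are then safe to read.
wait_eq(signal_buffer[dst][src], signal_target, torch.cuda.current_stream())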

tilelang/distributed/utils.py

Lines changed: 34 additions & 1 deletion
@@ -3,7 +3,7 @@
 import datetime
 import os
 import inspect
-from typing import List, Union, Tuple, Callable, Sequence
+from typing import List, Union, Tuple, Callable, Sequence, Optional
 from contextlib import contextmanager
 
 import importlib.metadata
@@ -263,3 +263,36 @@ def supports_p2p_native_atomic():
         cudart.cudaDeviceP2PAttr.cudaDevP2PAttrNativeAtomicSupported, 0, 1)
     CUDA_CHECK(err)
     return support == 1
+
+
+def set_signal(signal_tensor: torch.Tensor,
+               signal: int,
+               stream: Optional[torch.cuda.Stream] = None):
+    stream = stream or torch.cuda.current_stream()
+    if signal_tensor.dtype == torch.int32:
+        (err,) = cuda.cuStreamWriteValue32(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
+
+
+def wait_eq(signal_tensor: torch.Tensor,
+            signal: int,
+            stream: Optional[torch.cuda.Stream] = None,
+            require_i64=False):
+    stream = stream or torch.cuda.current_stream()
+    if signal_tensor.dtype == torch.int32:
+        (err,) = cuda.cuStreamWaitValue32(
+            stream.cuda_stream,
+            signal_tensor.data_ptr(),
+            signal,
+            cuda.CUstreamWaitValue_flags.CU_STREAM_WAIT_VALUE_EQ,
+        )
+        CUDA_CHECK(err)
+    else:
+        raise Exception(f"Unsupported signal dtype {signal_tensor.dtype}")
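
For reference, a minimal standalone sketch of how the two new helpers pair up across streams (the flag tensor, stream names, and value are illustrative): set_signal enqueues a 32-bit write of the flag on one stream, wait_eq enqueues a wait on another stream that releases once the flag equals the expected value, both default to the current stream when no stream is passed, and only int32 signal tensors are accepted.

import torch
from tilelang.distributed.utils import set_signal, wait_eq

flag = torch.zeros(1, dtype=torch.int32, device="cuda")  # int32 is the only supported dtype
producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

with torch.cuda.stream(producer):
    # ... enqueue a copy or kernel whose completion the flag should publish ...
    set_signal(flag, 1, producer)   # enqueues a 32-bit write of 1 on the producer stream

with torch.cuda.stream(consumer):
    wait_eq(flag, 1, consumer)      # consumer stream stalls until the flag equals 1
    # ... enqueue work that depends on the producer's result ...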

tilelang/utils/allocator.py

Lines changed: 13 additions & 2 deletions
@@ -125,7 +125,7 @@ def _alloc(self):
         if rc != 0:
             msg = _libcudart.cudaGetErrorString(rc)
             raise RuntimeError(f"cudaMalloc failed: {rc} {msg.decode() if msg else ''}")
-        self._ptr = self._base_ptr
+        self._ptr.value = self._base_ptr.value
 
     def _free(self):
         if getattr(self, "_base_ptr", None) and self._base_ptr.value:
@@ -166,6 +166,7 @@ def initialized(self) -> bool:
     def _allocate_tensor(self,
                          shape: Tuple[int, ...],
                          dtype: torch.dtype,
+                         return_peers=False,
                          take_ownership: bool = False) -> torch.Tensor:
 
         numel = _prod_shape(shape)
@@ -198,13 +199,23 @@ def _allocate_tensor(self,
 
         t = tensor_from_ptr(cur_ptr_val, shape, dtype_str, self._device, take_ownership)
 
+        if return_peers:
+            peer_ts = []
+            for i in range(self._group.size()):
+                if i == self._local_rank:
+                    peer_ts.append(t)
+                else:
+                    peer_ptr_val = int(self._buffer_ptrs[i]) + current_offset
+                    peer_t = tensor_from_ptr(peer_ptr_val, shape, dtype_str, self._device, False)
+                    peer_ts.append(peer_t)
+
         if take_ownership:
             self._ptr = ctypes.c_void_p(0)
         else:
             new_ptr_val = cur_ptr_val + bytes_alloc
             self._ptr.value = new_ptr_val
 
-        return t
+        return peer_ts if return_peers else t
 
     @property
     def ptr(self) -> int:
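
The invariant the new return_peers branch is built around, expressed as a small illustrative check (not part of the commit; it pokes at private fields and assumes an initialized distributed allocator): every tensor in the returned list starts at the same byte offset, the local rank's entry being the ordinary allocation and the others being views into the peers' symmetric buffers.

# Illustrative only, run on one rank of an initialized distributed allocator.
peers = allocator._allocate_tensor((256, 256), torch.float16, return_peers=True)
assert len(peers) == allocator._group.size()  # one view per rank in the group
offsets = [t.data_ptr() - int(p) for t, p in zip(peers, allocator._buffer_ptrs)]
assert len(set(offsets)) == 1                 # identical offset into every rank's buffer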

tilelang/utils/tensor.py

Lines changed: 3 additions & 2 deletions
@@ -45,14 +45,15 @@ def map_torch_type(intype: str) -> torch.dtype:
 def tensor(shape: Tuple[int, ...],
            dtype: torch.dtype,
            device: Optional[Union[str, torch.device, int]] = None,
-           allocator: Optional[BaseAllocator] = None) -> torch.Tensor:
+           allocator: Optional[BaseAllocator] = None,
+           return_peers: Optional[bool] = None) -> Union[torch.Tensor, list[torch.Tensor]]:
     if allocator is not None:
         assert allocator.initialized(), "Allocator is not initialized"
         if device is not None:
             device = parse_device(device)
             assert allocator.device == device, f"Allocator device must be the " \
                 f"same as the device of the tensor, but got {allocator.device} != {device}"
-        return allocator._allocate_tensor(shape, dtype)
+        return allocator._allocate_tensor(shape, dtype, return_peers)
     else:
         return torch.empty(shape, dtype=dtype, device=device)
 
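
From the caller's side, a minimal usage sketch (allocator, rank, and peer_rank are assumed to come from a setup like the example above): without return_peers the call behaves as before and returns a single tensor, while return_peers=True returns one tensor per rank, so indexing by rank selects whose copy of the buffer is addressed.

import torch
import tilelang

# assumes `allocator` was created via tilelang.get_allocator(..., is_distributed=True)
local_only = tilelang.tensor((1024, 1024), torch.float16, allocator=allocator)
peers = tilelang.tensor((1024, 1024), torch.float16, allocator=allocator, return_peers=True)

mine = peers[rank]           # this rank's own view, the same kind of tensor as local_only
theirs = peers[peer_rank]    # a peer's view; its data_ptr() can feed copy-engine puts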
