
Commit d8b26c7

[Feat] Support internode copy with intranode copy
1 parent 2cc1a2c commit d8b26c7

File tree

5 files changed: +157 -7 lines changed

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
import argparse
import torch
import torch.distributed as dist
import pynvshmem
import tilelang
import tilelang.language as T
import os
from tilelang.distributed.utils import init_distributed, dtype_map, perf_fn
from tilelang.env import env
from packaging import version
import importlib.metadata

cuda_python_version = importlib.metadata.version("cuda-python")
if version.parse(cuda_python_version) >= version.parse("12.8.0"):
    from cuda.bindings import driver as cuda
    from cuda.bindings import runtime as cudart
else:
    from cuda import cuda, cudart

# Launch on two nodes, e.g.:
# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py


def internode_gather(M, local_world_size, block_M, threads):
    # Internode phase: each PE pushes its tile to the PE with the same local
    # rank on the other node via a non-blocking NVSHMEM put.

    @T.prim_func
    def main(
            dst: T.Tensor((M), "float32"),
            src: T.Tensor((M), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # peer PE, 2 nodes
            T.putmem_nbi_block(
                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]),
                block_M * 4, rank[0])  # block_M float32 elements = block_M * 4 bytes

    return main


def intranode_gather(M, world_size, block_M, threads):
    # Intranode phase: each rank pushes its own tile and the tile received
    # from the other node to every rank in its node-local group.

    @T.prim_func
    def main(
            dst: T.Tensor((M * world_size), "float32"),
            src: T.Tensor((M * 2), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
            rank = T.alloc_local([1], "uint64")
            num_rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()
            num_rank[0] = T.get_num_ranks()
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(T.cast(rank[0], "int32"), msg="signal")
                T.print(T.cast(num_rank[0], "int32"), msg="signal")
            for k in T.serial(world_size // 2):  # 2 nodes, so world_size // 2 local peers
                T.put_block(
                    src=T.address_of(src[bx * block_M]),
                    dst=T.address_of(dst[bx * block_M + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )
                T.put_block(
                    src=T.address_of(src[bx * block_M + M]),
                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )

    return main


if __name__ == '__main__':
    tilelang.disable_cache()

    M = 2
    K = 12288
    # For 2 nodes (16 GPUs): world_size = 16, ranks are 0-15, local ranks are 0-7.
    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
        return_tp_group=True, return_lc_group=True)
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

    allocator = tilelang.get_allocator(
        size=2**25,
        device="cuda",
        is_distributed=True,
        local_rank=LOCAL_RANK,
        num_local_ranks=local_world_size,
        group=LC_GROUP)
    print(local_world_size, LOCAL_RANK)

    # Each rank sends its local tensor to the rank with the same local_rank on the other node.
    # Assuming 2 nodes, each with 4 workers:
    # tensor 0: [0] -> [4], tensor 4: [4] -> [0]
    # tensor 1: [1] -> [5], tensor 5: [5] -> [1]
    # tensor 2: [2] -> [6], tensor 6: [6] -> [2]
    # tensor 3: [3] -> [7], tensor 7: [7] -> [3]
    interkernel = tilelang.compile(internode_gather(M, local_world_size, M, 128))
    if LOCAL_RANK == 0:
        print(interkernel.get_kernel_source())
    src = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    dst = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK
    src.copy_(input_data)

    pynvshmem.nvshmem_barrier_all()
    dist.barrier(TP_GROUP)
    interkernel(dst, src)
    pynvshmem.nvshmem_barrier_all()

    # Each rank then sends its local tensor and the tensor received from the other node
    # to the other ranks on its own node:
    # tensors 0 and 4: [0] -> [1, 2, 3]
    # tensors 1 and 5: [1] -> [0, 2, 3]
    # tensors 2 and 6: [2] -> [0, 1, 3]
    # tensors 3 and 7: [3] -> [0, 1, 2]
    # tensors 0 and 4: [4] -> [5, 6, 7]
    # tensors 1 and 5: [5] -> [4, 6, 7]
    # tensors 2 and 6: [6] -> [4, 5, 7]
    # tensors 3 and 7: [7] -> [4, 5, 6]
    src_intra = tilelang.tensor((M * 2), torch.float32, allocator=allocator).normal_()
    dst_intra = tilelang.tensor((M * WORLD_SIZE), torch.float32, allocator=allocator)
    # Pack [own tile, received tile] on node 0 and [received tile, own tile] on node 1
    # (M float32 elements = M * 4 bytes).
    if RANK < WORLD_SIZE // 2:
        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
    else:
        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)

    env.USE_NVSHMEM = False
    intrakernel = tilelang.compile(
        intranode_gather(M, WORLD_SIZE, M, 128),
        pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
    intrakernel.initialize(allocator=allocator)
    if LOCAL_RANK == 0:
        print(intrakernel.get_kernel_source())
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)
    intrakernel(dst_intra, src_intra)
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)

    print(dst_intra)
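
The internode phase above hinges on the peer mapping computed inside internode_gather: (pe + local_world_size) % (2 * local_world_size). A minimal host-side sketch of that mapping, assuming the 2-node layout from the comments (the sizes below are illustrative only, not part of the commit):

# Sketch of the internode peer mapping for the assumed 2-node case.
local_world_size = 8
world_size = 2 * local_world_size
for pe in range(world_size):
    peer = (pe + local_world_size) % (2 * local_world_size)
    # PE 0 pairs with PE 8, PE 1 with PE 9, ..., PE 15 with PE 7:
    # each rank exchanges its tile with the same local index on the other node.
    assert (peer + local_world_size) % world_size == pe  # the mapping is symmetric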

src/target/codegen_cuda.cc

Lines changed: 2 additions & 2 deletions
@@ -181,8 +181,8 @@ std::string CodeGenTileLangCUDA::Finish() {
   }
 
   if (use_nvshmem_) {
-    decl_stream << "#include <nvshmem.h>>\n";
-    decl_stream << "#include <nvshmemx.h>>\n";
+    decl_stream << "#include <nvshmem.h>\n";
+    decl_stream << "#include <nvshmemx.h>\n";
   }
 
   if (need_cooperative_groups_) {

tilelang/distributed/launch.sh

Lines changed: 6 additions & 2 deletions
@@ -20,8 +20,12 @@ nproc_per_node=${GPUS:=$(nvidia-smi --list-gpus | wc -l)} # set env var. `GPUS`
 nnodes=${NODES:=1} # set env var. `NODES` to # of nodes
 node_rank=${NODE_RANK:=0} # set env var. `NODE_RANK` to the rank of current node
 
-master_addr="127.0.0.1"
-master_port="$(( RANDOM % 100 + 23400 ))" # randomly choose a port between 23400 and 23499
+master_addr=${ARNOLD_WORKER_0_HOST:="127.0.0.1"}
+if [ -z "${ARNOLD_WORKER_0_PORT}" ]; then
+    master_port="8361"
+else
+    master_port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1)
+fi
 additional_args="--rdzv_endpoint=${master_addr}:${master_port}"
 IB_HCA=mlx5
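
The new endpoint logic prefers ARNOLD_WORKER_0_HOST / ARNOLD_WORKER_0_PORT and falls back to 127.0.0.1:8361; when the port variable holds a comma-separated list, only the first entry is used. A rough Python equivalent of that shell logic (environment values are assumed for illustration):

import os

master_addr = os.environ.get("ARNOLD_WORKER_0_HOST", "127.0.0.1")
ports = os.environ.get("ARNOLD_WORKER_0_PORT", "")
master_port = ports.split(",")[0] if ports else "8361"
rdzv_endpoint = f"{master_addr}:{master_port}"  # passed to torchrun via --rdzv_endpoint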

tilelang/distributed/utils.py

Lines changed: 9 additions & 2 deletions
@@ -56,13 +56,14 @@ def init_dist(local_rank: int, num_local_ranks: int):
         list(range(num_local_ranks * num_nodes)))
 
 
-def init_distributed(return_tp_group=False, init_nvshmem=True):
+def init_distributed(return_tp_group=False, init_nvshmem=True, return_lc_group=False):
     WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
     RANK = int(os.environ.get("RANK", 0))
     LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
 
     torch.distributed.init_process_group(
         backend="nccl",
+        device_id=torch.device(f'cuda:{LOCAL_RANK}'),
         world_size=WORLD_SIZE,
         rank=RANK,
         timeout=datetime.timedelta(seconds=1800),
@@ -76,7 +77,13 @@ def init_distributed(return_tp_group=False, init_nvshmem=True):
         import pynvshmem
         pynvshmem.init_nvshmem_by_uniqueid(TP_GROUP)
 
-    if return_tp_group:
+    if return_lc_group:
+        local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
+        base = (RANK // local_world_size) * local_world_size
+        LC_GROUP = torch.distributed.new_group(list(range(base, base + local_world_size)), backend="nccl")
+        print(local_world_size, LC_GROUP, TP_GROUP)
+        return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
+    elif return_tp_group:
         return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP
     else:
         return WORLD_SIZE, RANK, LOCAL_RANK
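
With return_lc_group=True, LC_GROUP is a per-node NCCL group: each rank joins the group containing exactly the ranks of its own node (the added device_id argument also pins each process to cuda:LOCAL_RANK during init_process_group). A small sketch of the membership computation, assuming 16 ranks with 8 per node:

# Which ranks belong to a given rank's LC_GROUP, per the new code path.
local_world_size = 8  # assumed LOCAL_WORLD_SIZE

def lc_group_members(rank):
    base = (rank // local_world_size) * local_world_size
    return list(range(base, base + local_world_size))

print(lc_group_members(3))   # [0, 1, 2, 3, 4, 5, 6, 7]
print(lc_group_members(10))  # [8, 9, 10, 11, 12, 13, 14, 15]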

tilelang/utils/allocator.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def _init_table(self):
         ] * self._group.size()
         local_ipc_handle = _create_ipc_handle(self._base_ptr.value)
         dist.all_gather_object(ipc_handles, local_ipc_handle, self._group)
-        buffer_ptrs = torch.empty(self._group.size(), dtype=torch.uint64)
+        buffer_ptrs = torch.empty(self._group.size(), dtype=torch.uint64, device='cuda')
         _sync_ipc_handles(self._local_rank, device_ids,
                           ctypes.c_void_p(buffer_ptrs.data_ptr()).value, ipc_handles, None)
         buffer_ptrs[self._local_rank] = self._base_ptr.value
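
This one-line change allocates buffer_ptrs on the GPU, so buffer_ptrs.data_ptr() hands _sync_ipc_handles a device address instead of a host address; the diff itself does not state the motivation, so that reading is an assumption. A minimal illustration of the difference (n_ranks is a placeholder):

import torch

n_ranks = 8  # placeholder group size
host_table = torch.empty(n_ranks, dtype=torch.uint64)                   # CPU memory
device_table = torch.empty(n_ranks, dtype=torch.uint64, device='cuda')  # GPU memory
# .data_ptr() returns a host pointer for the first tensor and a CUDA device
# pointer for the second; only the latter is usable by device-side code.
print(host_table.device, device_table.device)  # cpu cuda:0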
