Commit 1e0e208

[Feat]format code

1 parent d8b26c7 commit 1e0e208

2 files changed (+44, -34 lines)

examples/distributed/example_overlapping_allgather.py

Lines changed: 40 additions & 31 deletions
@@ -10,6 +10,7 @@
 from tilelang.env import env
 from packaging import version
 import importlib.metadata
+
 cuda_python_version = importlib.metadata.version("cuda-python")
 if version.parse(cuda_python_version) >= version.parse("12.8.0"):
     from cuda.bindings import driver as cuda
@@ -19,6 +20,7 @@
 # NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
 # NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
 
+
 def internode_gather(M, local_world_size, block_M, threads):
 
     @T.prim_func
@@ -28,19 +30,20 @@ def main(
     ):
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
             rank = T.alloc_local([1], "uint64")
-            rank[0] = (T.get_pe()+local_world_size)%(2*local_world_size) # 2 nodes
+            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes
             T.putmem_nbi_block(
-                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]),
-                block_M *4 , rank[0])
+                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]), block_M * 4,
+                rank[0])
 
     return main
 
+
 def intranode_gather(M, world_size, block_M, threads):
 
     @T.prim_func
     def main(
-            dst: T.Tensor((M*world_size), "float32"),
-            src: T.Tensor((M*2), "float32"),
+            dst: T.Tensor((M * world_size), "float32"),
+            src: T.Tensor((M * 2), "float32"),
     ):
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
             rank = T.alloc_local([1], "uint64")
@@ -49,31 +52,33 @@ def main(
             num_rank[0] = T.get_num_ranks()
             tid = T.get_thread_binding()
             if tid == 0:
-                T.print(T.cast(rank[0],"int32"),msg="signal")
-                T.print(T.cast(num_rank[0],"int32"),msg="signal")
-            for k in T.serial(world_size//2): # 2 node
+                T.print(T.cast(rank[0], "int32"), msg="signal")
+                T.print(T.cast(num_rank[0], "int32"), msg="signal")
+            for k in T.serial(world_size // 2):  # 2 node
                 T.put_block(
                     src=T.address_of(src[bx * block_M]),
-                    dst=T.address_of(dst[bx * block_M + rank[0]*M]),
+                    dst=T.address_of(dst[bx * block_M + rank[0] * M]),
                     size=block_M,
                     dst_pe=k,
                 )
                 T.put_block(
                     src=T.address_of(src[bx * block_M + M]),
-                    dst=T.address_of(dst[bx * block_M + M*num_rank[0] + rank[0]*M]),
+                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0] * M]),
                     size=block_M,
                     dst_pe=k,
                 )
 
     return main
 
+
 if __name__ == '__main__':
     tilelang.disable_cache()
 
     M = 2
     K = 12288
-    #for 2 node(16 GPUs), world_size=16,rank is 0-15,local rank is 0-7
-    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(return_tp_group=True,return_lc_group=True)
+    #for 2 node(16 GPUs), world_size=16,rank is 0-15,local rank is 0-7
+    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
+        return_tp_group=True, return_lc_group=True)
     local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
     LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
 
@@ -84,7 +89,7 @@ def main(
         local_rank=LOCAL_RANK,
         num_local_ranks=local_world_size,
         group=LC_GROUP)
-    print(local_world_size,LOCAL_RANK)
+    print(local_world_size, LOCAL_RANK)
 
     # Each rank sends the local_tensor to ranks of other nodes with the same local_rank
     # Assuming there are 2 nodes, each with 4 workers
@@ -93,19 +98,17 @@ def main(
     # 2-th local tensor ([2] -> [6]), 6-th local tensor ([6] -> [2])
     # 3-th local tensor ([3] -> [7]), 7-th local tensor ([7] -> [3])
     interkernel = tilelang.compile(internode_gather(M, local_world_size, M, 128))
-    if LOCAL_RANK==0:
+    if LOCAL_RANK == 0:
         print(interkernel.get_kernel_source())
-    src = pynvshmem.nvshmem_create_tensor(
-        [M], torch.float32)
-    dst = pynvshmem.nvshmem_create_tensor(
-        [M], torch.float32)
+    src = pynvshmem.nvshmem_create_tensor([M], torch.float32)
+    dst = pynvshmem.nvshmem_create_tensor([M], torch.float32)
     input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK
     src.copy_(input_data)
 
-    pynvshmem.nvshmem_barrier_all()
+    pynvshmem.nvshmem_barrier_all()
     dist.barrier(TP_GROUP)
     interkernel(dst, src)
-    pynvshmem.nvshmem_barrier_all()
+    pynvshmem.nvshmem_barrier_all()
 
     # Each rank sends the local_tensor and the received internode tensors to intranode ranks.
     # 0-th and 4-th local tensors ([0]->[1,2,3])
@@ -116,24 +119,30 @@ def main(
     # 1-th and 5-th local tensors ([5]->[4,6,7])
     # 2-th and 6-th local tensors ([6]->[4,5,7])
     # 3-th and 7-th local tensors ([7]->[4,5,6])
-    src_intra = tilelang.tensor((M*2), torch.float32, allocator=allocator).normal_()
-    dst_intra = tilelang.tensor((M*WORLD_SIZE), torch.float32, allocator=allocator)
-    if RANK<WORLD_SIZE/2:
-        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr()+M*4, dst.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+    src_intra = tilelang.tensor((M * 2), torch.float32, allocator=allocator).normal_()
+    dst_intra = tilelang.tensor((M * WORLD_SIZE), torch.float32, allocator=allocator)
+    if RANK < WORLD_SIZE / 2:
+        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
     else:
-        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr()+M*4, src.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
 
-    env.USE_NVSHMEM=False
-    intrakernel = tilelang.compile(intranode_gather(M, WORLD_SIZE, M, 128),pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
+    env.USE_NVSHMEM = False
+    intrakernel = tilelang.compile(
+        intranode_gather(M, WORLD_SIZE, M, 128),
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
     intrakernel.initialize(allocator=allocator)
-    if LOCAL_RANK==0:
+    if LOCAL_RANK == 0:
         print(intrakernel.get_kernel_source())
     torch.cuda.synchronize()
     torch.distributed.barrier(LC_GROUP)
     intrakernel(dst_intra, src_intra)
     torch.cuda.synchronize()
     torch.distributed.barrier(LC_GROUP)
 
-    print(dst_intra)
+    print(dst_intra)
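
For orientation, a minimal dependency-free sketch of the two-stage all-gather these kernels implement (an illustration alongside the diff, not code from the commit): each rank first exchanges its shard with the peer that has the same local index on the other node, then shares both shards inside its own node. The peer formula (pe + local_world_size) % (2 * local_world_size) is the one used by internode_gather above; the names NODES, LOCAL_WORLD_SIZE, internode_peer and simulate are made up for the sketch, and the bookkeeping only tracks which rank holds which shard, not the byte offsets used by intranode_gather.

# Illustrative sketch only; not part of the commit. Values below are assumed
# (2 nodes x 4 local ranks) just to make the data movement visible.
NODES = 2
LOCAL_WORLD_SIZE = 4
WORLD_SIZE = NODES * LOCAL_WORLD_SIZE


def internode_peer(pe):
    # Same mapping as internode_gather: pair with the rank that has the same
    # local index on the other node (valid for exactly 2 nodes).
    return (pe + LOCAL_WORLD_SIZE) % (2 * LOCAL_WORLD_SIZE)


def simulate():
    # Stage 0: each rank owns one shard, identified here by its rank id.
    local = {pe: pe for pe in range(WORLD_SIZE)}

    # Stage 1 (internode): receive the peer's shard and keep it next to your own.
    staged = {pe: (local[pe], local[internode_peer(pe)]) for pe in range(WORLD_SIZE)}

    # Stage 2 (intranode): every rank shares both shards with the other ranks
    # on its node; assembling them in rank order yields the full all-gather.
    result = {}
    for pe in range(WORLD_SIZE):
        node = pe // LOCAL_WORLD_SIZE
        gathered = [None] * WORLD_SIZE
        for src_pe in range(node * LOCAL_WORLD_SIZE, (node + 1) * LOCAL_WORLD_SIZE):
            own, remote = staged[src_pe]
            gathered[src_pe] = own
            gathered[internode_peer(src_pe)] = remote
        result[pe] = gathered
    return result


if __name__ == "__main__":
    out = simulate()
    assert all(out[pe] == list(range(WORLD_SIZE)) for pe in out)
    print("rank 0 ends up with shards:", out[0])

Running the sketch confirms that after both stages every rank holds all WORLD_SIZE shards in rank order, which is the expected all-gather result and matches the send pattern described in the comments above.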

tilelang/distributed/utils.py

Lines changed: 4 additions & 3 deletions
@@ -63,7 +63,7 @@ def init_distributed(return_tp_group=False, init_nvshmem=True, return_lc_group=F
 
     torch.distributed.init_process_group(
         backend="nccl",
-        device_id=torch.device(f'cuda:{LOCAL_RANK}'),
+        device_id=torch.device(f'cuda:{LOCAL_RANK}'),
         world_size=WORLD_SIZE,
         rank=RANK,
         timeout=datetime.timedelta(seconds=1800),
@@ -80,8 +80,9 @@ def init_distributed(return_tp_group=False, init_nvshmem=True, return_lc_group=F
     if return_lc_group:
         local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
         base = (RANK // local_world_size) * local_world_size
-        LC_GROUP = torch.distributed.new_group(list(range(base, base + local_world_size)), backend="nccl")
-        print(local_world_size,LC_GROUP,TP_GROUP)
+        LC_GROUP = torch.distributed.new_group(
+            list(range(base, base + local_world_size)), backend="nccl")
+
         return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP
     elif return_tp_group:
         return WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP
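
For reference, a small standalone snippet (not from the commit) showing which global ranks the return_lc_group branch groups together. The grouping rule base = (RANK // local_world_size) * local_world_size is copied from the hunk above; WORLD_SIZE and LOCAL_WORLD_SIZE here are assumed values for the 2-node, 16-GPU setup the example targets.

# Illustration only; not part of the commit. Sizes below are assumptions.
WORLD_SIZE = 16       # assumed: 2 nodes
LOCAL_WORLD_SIZE = 8  # assumed: 8 GPUs per node

for rank in range(WORLD_SIZE):
    # Same rule as the hunk above: every rank maps to its node's base rank.
    base = (rank // LOCAL_WORLD_SIZE) * LOCAL_WORLD_SIZE
    lc_members = list(range(base, base + LOCAL_WORLD_SIZE))
    if rank == base:  # report each node-local group once
        print(f"node {rank // LOCAL_WORLD_SIZE}: LC_GROUP ranks = {lc_members}")

# Expected output:
# node 0: LC_GROUP ranks = [0, 1, 2, 3, 4, 5, 6, 7]
# node 1: LC_GROUP ranks = [8, 9, 10, 11, 12, 13, 14, 15]

In the example script above, the caller unpacks these groups as WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(return_tp_group=True, return_lc_group=True) and then uses LC_GROUP for the intranode barriers.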
