@@ -10,6 +10,7 @@
 from tilelang.env import env
 from packaging import version
 import importlib.metadata
+
 cuda_python_version = importlib.metadata.version("cuda-python")
 if version.parse(cuda_python_version) >= version.parse("12.8.0"):
     from cuda.bindings import driver as cuda
@@ -19,6 +20,7 @@
 # NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
 # NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
 
+
 def internode_gather(M, local_world_size, block_M, threads):
 
     @T.prim_func
@@ -28,19 +30,20 @@ def main(
     ):
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
             rank = T.alloc_local([1], "uint64")
-            rank[0] = (T.get_pe()+local_world_size)%(2*local_world_size) # 2 nodes
+            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)  # 2 nodes
             T.putmem_nbi_block(
-                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]),
-                block_M * 4, rank[0])
+                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]), block_M * 4,
+                rank[0])
 
     return main
 
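The peer-rank expression above pairs each PE with the PE holding the same local slot on the other node. Below is a standalone pure-Python sketch of that mapping, an illustration only (it assumes 2 nodes with local_world_size = 8, i.e. the 16-GPU setup used in __main__, and is not part of the diff):

# Sketch: verify the internode peer mapping used in internode_gather,
# assuming 2 nodes with local_world_size = 8 (16 ranks total).
local_world_size = 8
world_size = 2 * local_world_size
for pe in range(world_size):
    peer = (pe + local_world_size) % world_size
    assert peer % local_world_size == pe % local_world_size  # same local slot
    assert peer // local_world_size != pe // local_world_size  # on the other node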
+
 def intranode_gather(M, world_size, block_M, threads):
 
     @T.prim_func
     def main(
-        dst: T.Tensor((M * world_size), "float32"),
-        src: T.Tensor((M * 2), "float32"),
+            dst: T.Tensor((M * world_size), "float32"),
+            src: T.Tensor((M * 2), "float32"),
     ):
         with T.Kernel(T.ceildiv(M, block_M), threads=threads) as (bx):
             rank = T.alloc_local([1], "uint64")
@@ -49,31 +52,33 @@ def main(
             num_rank[0] = T.get_num_ranks()
             tid = T.get_thread_binding()
             if tid == 0:
-                T.print(T.cast(rank[0],"int32"),msg="signal")
-                T.print(T.cast(num_rank[0],"int32"),msg="signal")
-            for k in T.serial(world_size//2): # 2 node
+                T.print(T.cast(rank[0], "int32"), msg="signal")
+                T.print(T.cast(num_rank[0], "int32"), msg="signal")
+            for k in T.serial(world_size // 2):  # 2 nodes
                 T.put_block(
                     src=T.address_of(src[bx * block_M]),
-                    dst=T.address_of(dst[bx * block_M + rank[0]*M]),
+                    dst=T.address_of(dst[bx * block_M + rank[0] * M]),
                     size=block_M,
                     dst_pe=k,
                 )
                 T.put_block(
                     src=T.address_of(src[bx * block_M + M]),
-                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0]*M]),
+                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0] * M]),
                     size=block_M,
                     dst_pe=k,
                 )
 
     return main
 
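The pair of put_block calls above fixes the destination layout: each rank's own chunk lands at offset rank * M, and the chunk it received over the network lands at offset M * num_rank + rank * M. A standalone sketch of those offsets, illustration only (it assumes T.get_num_ranks() reports the 8 intranode PEs and M = 2, and is not part of the diff):

# Sketch: destination offsets produced by intranode_gather,
# assuming num_rank = 8 intranode PEs and M = 2.
M, num_rank = 2, 8

def dst_offsets(rank):
    local_chunk = rank * M                  # where src[0:M] lands on every peer
    remote_chunk = M * num_rank + rank * M  # where src[M:2*M] lands
    return local_chunk, remote_chunk

print(dst_offsets(3))  # (6, 22)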
+
 if __name__ == '__main__':
     tilelang.disable_cache()
 
     M = 2
     K = 12288
-    #for 2 node(16 GPUs), world_size=16,rank is 0-15,local rank is 0-7
-    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(return_tp_group=True,return_lc_group=True)
+    # for 2 nodes (16 GPUs): world_size=16, ranks are 0-15, local ranks are 0-7
+    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
+        return_tp_group=True, return_lc_group=True)
     local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
     LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
 
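RANK, LOCAL_RANK, and LOCAL_WORLD_SIZE follow the usual torchrun-style conventions. A standalone sketch of how the per-node indices relate to the global rank, illustration only (it assumes those semantics and is not part of the diff):

# Sketch: deriving node-local indices, assuming torchrun-style env vars.
import os

rank = int(os.environ.get("RANK", 0))
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
node_rank = rank // local_world_size  # which node this process runs on
local_rank = rank % local_world_size  # GPU index within that node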
@@ -84,7 +89,7 @@ def main(
         local_rank=LOCAL_RANK,
         num_local_ranks=local_world_size,
         group=LC_GROUP)
-    print(local_world_size,LOCAL_RANK)
+    print(local_world_size, LOCAL_RANK)
 
     # Each rank sends the local_tensor to ranks of other nodes with the same local_rank
     # Assuming there are 2 nodes, each with 4 workers
@@ -93,19 +98,17 @@ def main(
     # 2-th local tensor ([2] -> [6]), 6-th local tensor ([6] -> [2])
     # 3-th local tensor ([3] -> [7]), 7-th local tensor ([7] -> [3])
     interkernel = tilelang.compile(internode_gather(M, local_world_size, M, 128))
-    if LOCAL_RANK == 0 :
+    if LOCAL_RANK == 0:
         print(interkernel.get_kernel_source())
-    src = pynvshmem.nvshmem_create_tensor(
-        [M], torch.float32)
-    dst = pynvshmem.nvshmem_create_tensor(
-        [M], torch.float32)
+    src = pynvshmem.nvshmem_create_tensor([M], torch.float32)
+    dst = pynvshmem.nvshmem_create_tensor([M], torch.float32)
     input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK
     src.copy_(input_data)
 
-    pynvshmem.nvshmem_barrier_all()
+    pynvshmem.nvshmem_barrier_all()
     dist.barrier(TP_GROUP)
     interkernel(dst, src)
-    pynvshmem.nvshmem_barrier_all()
+    pynvshmem.nvshmem_barrier_all()
 
     # Each rank sends the local_tensor and the received internode tensors to intranode ranks.
     # 0-th and 4-th local tensors ([0]->[1,2,3])
@@ -116,24 +119,30 @@ def main(
     # 1-th and 5-th local tensors ([5]->[4,6,7])
     # 2-th and 6-th local tensors ([6]->[4,5,7])
     # 3-th and 7-th local tensors ([7]->[4,5,6])
-    src_intra = tilelang.tensor((M*2), torch.float32, allocator=allocator).normal_()
-    dst_intra = tilelang.tensor((M*WORLD_SIZE), torch.float32, allocator=allocator)
-    if RANK < WORLD_SIZE / 2 :
-        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr()+M*4, dst.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+    src_intra = tilelang.tensor((M * 2), torch.float32, allocator=allocator).normal_()
+    dst_intra = tilelang.tensor((M * WORLD_SIZE), torch.float32, allocator=allocator)
+    if RANK < WORLD_SIZE / 2:
+        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
     else:
-        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
-        cudart.cudaMemcpy(src_intra.data_ptr()+M*4, src.data_ptr(), M*4, cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
+        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
+                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
 
-    env.USE_NVSHMEM = False
-    intrakernel = tilelang.compile(intranode_gather(M, WORLD_SIZE, M, 128),pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
+    env.USE_NVSHMEM = False
+    intrakernel = tilelang.compile(
+        intranode_gather(M, WORLD_SIZE, M, 128),
+        pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
     intrakernel.initialize(allocator=allocator)
-    if LOCAL_RANK == 0 :
+    if LOCAL_RANK == 0:
         print(intrakernel.get_kernel_source())
     torch.cuda.synchronize()
     torch.distributed.barrier(LC_GROUP)
     intrakernel(dst_intra, src_intra)
     torch.cuda.synchronize()
     torch.distributed.barrier(LC_GROUP)
 
-    print(dst_intra)
+    print(dst_intra)
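If the two phases behave as the comments describe, every rank finishes with all 16 chunks in rank order. A standalone sketch of the expected result, illustration only (it assumes 2 nodes x 8 GPUs, M = 2, and src holding the value RANK as set above; not part of the diff):

# Sketch: expected dst_intra on every rank after both gather phases,
# assuming world_size = 16 and M = 2.
import torch

M, world_size = 2, 16
expected = torch.arange(world_size, dtype=torch.float32).repeat_interleave(M)
# dst_intra should match: [0., 0., 1., 1., ..., 15., 15.]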