import torch
import torch.distributed as dist
import pynvshmem
import tilelang
import tilelang.language as T
import os
from tilelang.distributed.utils import init_distributed
from tilelang.env import env
from packaging import version
import importlib.metadata

# cuda-python >= 12.8 moved the driver/runtime modules under cuda.bindings.
cuda_python_version = importlib.metadata.version("cuda-python")
if version.parse(cuda_python_version) >= version.parse("12.8.0"):
    from cuda.bindings import driver as cuda
    from cuda.bindings import runtime as cudart
else:
    from cuda import cuda, cudart
# NODES=2 NODE_RANK=0 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py
# NODES=2 NODE_RANK=1 ARNOLD_WORKER_0_HOST=ip0 bash tilelang/distributed/launch.sh ./examples/distributed/example_overlapping_allgather.py

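# Two-phase hierarchical all-gather across 2 nodes (a sketch of the flow, as
# spelled out by the comments in __main__ below):
#   Phase 1 (internode_gather): each rank exchanges its local tensor with the
#   rank that has the same local rank on the other node, via NVSHMEM
#   putmem_nbi_block.
#   Phase 2 (intranode_gather): each rank broadcasts its own tensor plus the
#   tensor received in phase 1 to all peers on its node via P2P put_block, so
#   every rank ends up with all WORLD_SIZE tensors.
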
def internode_gather(M, local_world_size, block_M, threads):

    @T.prim_func
    def main(
            dst: T.Tensor((M,), "float32"),
            src: T.Tensor((M,), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as bx:
            rank = T.alloc_local([1], "uint64")
            # Peer PE: the rank with the same local rank on the other node
            # (2-node setup), e.g. with local_world_size=8: PE 3 -> 11, PE 11 -> 3.
            rank[0] = (T.get_pe() + local_world_size) % (2 * local_world_size)
            # Non-blocking put of block_M float32 elements (4 bytes each) to the peer.
            T.putmem_nbi_block(
                T.address_of(dst[bx * block_M]), T.address_of(src[bx * block_M]),
                block_M * 4, rank[0])

    return main

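# Phase 2 kernel: every local rank pushes both M-sized halves of `src` (its own
# tensor and the tensor received in phase 1) into the matching slots of `dst`
# on every peer of its node; `dst` has room for world_size tensors of size M.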
def intranode_gather(M, world_size, block_M, threads):

    @T.prim_func
    def main(
            dst: T.Tensor((M * world_size,), "float32"),
            src: T.Tensor((M * 2,), "float32"),
    ):
        with T.Kernel(T.ceildiv(M, block_M), threads=threads) as bx:
            rank = T.alloc_local([1], "uint64")
            num_rank = T.alloc_local([1], "uint64")
            rank[0] = T.get_rank()
            num_rank[0] = T.get_num_ranks()
            tid = T.get_thread_binding()
            if tid == 0:
                T.print(T.cast(rank[0], "int32"), msg="rank")
                T.print(T.cast(num_rank[0], "int32"), msg="num_ranks")
            # Push both halves of src to every peer on this node
            # (world_size // 2 local ranks per node for 2 nodes).
            for k in T.serial(world_size // 2):
                # Own local tensor -> slot `rank` of the gathered buffer.
                T.put_block(
                    src=T.address_of(src[bx * block_M]),
                    dst=T.address_of(dst[bx * block_M + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )
                # Tensor received from the other node -> slot `num_rank + rank`.
                T.put_block(
                    src=T.address_of(src[bx * block_M + M]),
                    dst=T.address_of(dst[bx * block_M + M * num_rank[0] + rank[0] * M]),
                    size=block_M,
                    dst_pe=k,
                )

    return main

if __name__ == '__main__':
    tilelang.disable_cache()

    M = 2
    # For 2 nodes (16 GPUs): world_size=16, rank is 0-15, local rank is 0-7.
    WORLD_SIZE, RANK, LOCAL_RANK, TP_GROUP, LC_GROUP = init_distributed(
        return_tp_group=True, return_lc_group=True)
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

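    # Distributed allocator: a symmetric buffer pool shared by the local ranks,
    # so tensors created from it below can be written directly by peers from the
    # intranode P2P kernel.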
    allocator = tilelang.get_allocator(
        size=2**25,
        device="cuda",
        is_distributed=True,
        local_rank=LOCAL_RANK,
        num_local_ranks=local_world_size,
        group=LC_GROUP)
    print(f"local_world_size={local_world_size}, local_rank={LOCAL_RANK}")

    # Phase 1: each rank sends its local tensor to the rank on the other node
    # that has the same local_rank.
    # Assuming there are 2 nodes, each with 4 workers:
    # 0-th local tensor ([0] -> [4]), 4-th local tensor ([4] -> [0])
    # 1-th local tensor ([1] -> [5]), 5-th local tensor ([5] -> [1])
    # 2-th local tensor ([2] -> [6]), 6-th local tensor ([6] -> [2])
    # 3-th local tensor ([3] -> [7]), 7-th local tensor ([7] -> [3])
    interkernel = tilelang.compile(internode_gather(M, local_world_size, M, 128))
    if LOCAL_RANK == 0:
        print(interkernel.get_kernel_source())
    src = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    dst = pynvshmem.nvshmem_create_tensor([M], torch.float32)
    input_data = torch.ones([M], dtype=torch.float32, device='cuda') * RANK
    src.copy_(input_data)

    pynvshmem.nvshmem_barrier_all()
    dist.barrier(TP_GROUP)
    interkernel(dst, src)
    pynvshmem.nvshmem_barrier_all()
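    # After the barrier, dst on rank r holds the tensor pushed by its
    # cross-node peer, i.e. the value (r + local_world_size) % WORLD_SIZE.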

    # Phase 2: each rank sends its local tensor and the tensor received from
    # the other node to the other ranks on its node.
    # 0-th and 4-th local tensors ([0]->[1,2,3])
    # 1-th and 5-th local tensors ([1]->[0,2,3])
    # 2-th and 6-th local tensors ([2]->[0,1,3])
    # 3-th and 7-th local tensors ([3]->[0,1,2])
    # 0-th and 4-th local tensors ([4]->[5,6,7])
    # 1-th and 5-th local tensors ([5]->[4,6,7])
    # 2-th and 6-th local tensors ([6]->[4,5,7])
    # 3-th and 7-th local tensors ([7]->[4,5,6])
    src_intra = tilelang.tensor((M * 2,), torch.float32, allocator=allocator).normal_()
    dst_intra = tilelang.tensor((M * WORLD_SIZE,), torch.float32, allocator=allocator)
    if RANK < WORLD_SIZE // 2:
        cudart.cudaMemcpy(src_intra.data_ptr(), src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
    else:
        cudart.cudaMemcpy(src_intra.data_ptr(), dst.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
        cudart.cudaMemcpy(src_intra.data_ptr() + M * 4, src.data_ptr(), M * 4,
                          cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice)
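    # src_intra now holds [node-0 tensor, node-1 tensor] for this cross-node
    # rank pair, ordered by global rank regardless of which node we are on.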

    env.USE_NVSHMEM = False
    intrakernel = tilelang.compile(
        intranode_gather(M, WORLD_SIZE, M, 128),
        pass_configs={tilelang.PassConfigKey.TL_DISABLE_RDC: True})
    intrakernel.initialize(allocator=allocator)
    if LOCAL_RANK == 0:
        print(intrakernel.get_kernel_source())
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)
    intrakernel(dst_intra, src_intra)
    torch.cuda.synchronize()
    torch.distributed.barrier(LC_GROUP)

    print(dst_intra)
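    # Hedged correctness check (a sketch, assuming the 2-node layout above and
    # that tilelang.tensor returns a regular torch.Tensor): after both phases,
    # the i-th M-sized slot of dst_intra should hold global rank i's input,
    # i.e. the constant value i.
    expected = torch.arange(WORLD_SIZE, dtype=torch.float32, device='cuda').repeat_interleave(M)
    if torch.allclose(dst_intra, expected):
        print(f"rank {RANK}: all-gather result matches the expected rank-ordered layout")
    else:
        print(f"rank {RANK}: MISMATCH\n  got      {dst_intra}\n  expected {expected}")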