|
| 1 | +import os |
| 2 | +import tilelang |
| 3 | +import tilelang.language as T |
| 4 | +import argparse |
| 5 | +import torch |
| 6 | +import torch.distributed as dist |
| 7 | +import torch.multiprocessing |
| 8 | +from tilelang.distributed.utils import init_dist |
| 9 | +from cuda import cudart |
| 10 | +from tilelang.distributed.utils import set_signal, wait_eq |
| 11 | + |
| 12 | +tilelang.disable_cache() |
| 13 | +os.environ['NCCL_DEBUG'] = 'WARN' # silence NCCL log |
| 14 | + |
| 15 | + |
| 16 | +def gemm_kernel(M, |
| 17 | + N, |
| 18 | + K, |
| 19 | + num_rank, |
| 20 | + block_M, |
| 21 | + block_N, |
| 22 | + block_K, |
| 23 | + threads, |
| 24 | + dtype="float16", |
| 25 | + accum_dtype="float"): |
| 26 | + |
| 27 | + @T.prim_func |
| 28 | + def main( |
| 29 | + A: T.Tensor((M, K), dtype), |
| 30 | + B: T.Tensor((K, N // num_rank), dtype), |
| 31 | + C: T.Tensor((M, N // num_rank), dtype), |
| 32 | + ): |
| 33 | + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): |
| 34 | + A_shared = T.alloc_shared((block_M, block_K), dtype) |
| 35 | + B_shared = T.alloc_shared((block_K, block_N), dtype) |
| 36 | + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) |
| 37 | + |
| 38 | + T.clear(C_local) |
| 39 | + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): |
| 40 | + T.copy(A[by * block_M, k * block_K], A_shared) |
| 41 | + T.copy(B[k * block_K, bx * block_N], B_shared) |
| 42 | + T.gemm(A_shared, B_shared, C_local) |
| 43 | + T.copy(C_local, C[by * block_M, bx * block_N]) |
| 44 | + |
| 45 | + return main |
| 46 | + |
| 47 | + |
| 48 | +def cp_engine_producer_all_gather_put(local_tensor, ag_buffer, signal_buffer, M_per_rank, N, |
| 49 | + signal_target, rank, local_world_size, world_size, |
| 50 | + intranode_ag_stream): |
| 51 | + local_rank = rank % local_world_size |
| 52 | + n_nodes = world_size // local_world_size |
| 53 | + node_rank = rank // local_world_size |
| 54 | + |
| 55 | + for i in range(1, local_world_size): |
| 56 | + segment = rank * M_per_rank * N |
| 57 | + local_dst_rank = (local_rank + local_world_size - i) % local_world_size |
| 58 | + src_ptr = ag_buffer[local_rank].data_ptr() + segment * local_tensor.element_size() |
| 59 | + dst_ptr = ag_buffer[local_dst_rank].data_ptr() + segment * local_tensor.element_size() |
| 60 | + # Using copy engine to perform intranode transmission |
| 61 | + # Sending rank-th local tensor to other ranks inside the node. |
| 62 | + (err,) = cudart.cudaMemcpyAsync( |
| 63 | + dst_ptr, |
| 64 | + src_ptr, |
| 65 | + M_per_rank * N * local_tensor.element_size(), |
| 66 | + cudart.cudaMemcpyKind.cudaMemcpyDefault, |
| 67 | + intranode_ag_stream.cuda_stream, |
| 68 | + ) |
| 69 | + # Notify the peer that the transmission is done. |
| 70 | + set_signal(signal_buffer[local_dst_rank][rank], signal_target, intranode_ag_stream) |
| 71 | + |
| 72 | + for i in range(1, n_nodes): |
| 73 | + recv_rank = local_rank + (node_rank + n_nodes - i) % n_nodes * local_world_size |
| 74 | + recv_segment = recv_rank * M_per_rank * N |
| 75 | + # Waiting for the internode data ready |
| 76 | + wait_eq(signal_buffer[local_rank][recv_rank], signal_target, intranode_ag_stream) |
| 77 | + src_ptr = ag_buffer[local_rank].data_ptr() + recv_segment * local_tensor.element_size() |
| 78 | + for j in range(1, local_world_size): |
| 79 | + local_dst_rank = (local_rank + local_world_size - j) % local_world_size |
| 80 | + dst_ptr = ag_buffer[local_dst_rank].data_ptr( |
| 81 | + ) + recv_segment * local_tensor.element_size() |
| 82 | + # Sending (local_rank + j*local_world_size) % world_size -th local tensor to other ranks inside the node. |
| 83 | + (err,) = cudart.cudaMemcpyAsync( |
| 84 | + dst_ptr, |
| 85 | + src_ptr, |
| 86 | + M_per_rank * N * local_tensor.element_size(), |
| 87 | + cudart.cudaMemcpyKind.cudaMemcpyDefault, |
| 88 | + intranode_ag_stream.cuda_stream, |
| 89 | + ) |
| 90 | + # Notify the peer that the transmission is done. |
| 91 | + set_signal(signal_buffer[local_dst_rank][recv_rank], signal_target, intranode_ag_stream) |
| 92 | + |
| 93 | + |
| 94 | +def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, rank, group, |
| 95 | + local_world_size, world_size, gemm_kernel, ag_stream): |
| 96 | + |
| 97 | + dist.barrier(group) |
| 98 | + |
| 99 | + # all_gather A to ag_buffer |
| 100 | + with torch.cuda.stream(ag_stream): |
| 101 | + cp_engine_producer_all_gather_put(A, ag_buffer, signal_buffer, M_per_rank, N, signal_target, |
| 102 | + rank, local_world_size, world_size, ag_stream) |
| 103 | + |
| 104 | + current_stream = torch.cuda.current_stream() |
| 105 | + current_stream.wait_stream(ag_stream) |
| 106 | + |
| 107 | + dist.barrier(group) |
| 108 | + torch.cuda.synchronize() |
| 109 | + |
| 110 | + torch.cuda.synchronize() |
| 111 | + torch.distributed.barrier(group) |
| 112 | + gemm_kernel(ag_buffer[rank], B, C) |
| 113 | + torch.cuda.synchronize() |
| 114 | + torch.distributed.barrier(group) |
| 115 | + |
| 116 | + return C |
| 117 | + |
| 118 | + |
| 119 | +def torch_ag_gemm( |
| 120 | + pg: torch.distributed.ProcessGroup, |
| 121 | + local_input: torch.Tensor, |
| 122 | + local_weight: torch.Tensor, |
| 123 | + ag_out: torch.Tensor, |
| 124 | +): |
| 125 | + torch.distributed.all_gather_into_tensor(ag_out, local_input, pg) |
| 126 | + ag_gemm_output = torch.matmul(ag_out, local_weight) |
| 127 | + return ag_gemm_output |
| 128 | + |
| 129 | + |
| 130 | +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): |
| 131 | + dtype = torch.float16 |
| 132 | + M = args.M if args else 8192 |
| 133 | + N = args.N if args else 8192 |
| 134 | + K = args.K if args else 8192 |
| 135 | + M_per_rank = M // num_local_ranks |
| 136 | + N_per_rank = N // num_local_ranks |
| 137 | + |
| 138 | + BLOCK_M = 128 |
| 139 | + BLOCK_N = 128 |
| 140 | + BLOCK_K = 64 |
| 141 | + threads = 256 |
| 142 | + assert num_local_ranks == 2, "this example only supports 2 ranks copying to each other" |
| 143 | + |
| 144 | + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) |
| 145 | + allocator = tilelang.get_allocator( |
| 146 | + size=2**30, |
| 147 | + device="cuda", |
| 148 | + is_distributed=True, |
| 149 | + local_rank=local_rank, |
| 150 | + num_local_ranks=num_local_ranks, |
| 151 | + group=group) |
| 152 | + kernel = tilelang.compile(gemm_kernel(M, N, K, num_ranks, BLOCK_M, BLOCK_N, BLOCK_K, threads)) |
| 153 | + kernel.initialize(allocator=allocator) |
| 154 | + if local_rank == 0: |
| 155 | + print(kernel.get_kernel_source()) |
| 156 | + |
| 157 | + A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_() |
| 158 | + B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_() |
| 159 | + C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator) |
| 160 | + ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True) |
| 161 | + signal_buffer = tilelang.tensor((num_local_ranks,), |
| 162 | + torch.int32, |
| 163 | + allocator=allocator, |
| 164 | + return_peers=True) |
| 165 | + signal_buffer[rank].fill_(0) |
| 166 | + ag_buffer[rank][rank * M_per_rank:(rank + 1) * M_per_rank, :].copy_(A) |
| 167 | + |
| 168 | + dist.barrier(group) |
| 169 | + |
| 170 | + ag_stream = torch.cuda.Stream() |
| 171 | + signal_target = 1 |
| 172 | + |
| 173 | + tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, rank, |
| 174 | + group, num_local_ranks, num_local_ranks, kernel, ag_stream) |
| 175 | + |
| 176 | + torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda") |
| 177 | + torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer) |
| 178 | + |
| 179 | + if torch.allclose(torch_C, tilelang_C, atol=1e-6, rtol=1e-6): |
| 180 | + print(f"rank {local_rank} check passed.✅") |
| 181 | + else: |
| 182 | + print(f"rank {local_rank} check failed.❌") |
| 183 | + print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}") |
| 184 | + raise ValueError("Test failed") |
| 185 | + |
| 186 | + dist.destroy_process_group() |
| 187 | + |
| 188 | + |
| 189 | +if __name__ == "__main__": |
| 190 | + parser = argparse.ArgumentParser() |
| 191 | + parser.add_argument( |
| 192 | + '--num-processes', type=int, default=2, help='Number of processes to spawn (default: 2)') |
| 193 | + parser.add_argument('--M', type=int, default=8192, help='M dimension') |
| 194 | + parser.add_argument('--N', type=int, default=8192, help='N dimension') |
| 195 | + parser.add_argument('--K', type=int, default=8192, help='K dimension') |
| 196 | + args = parser.parse_args() |
| 197 | + num_processes = args.num_processes |
| 198 | + |
| 199 | + torch.multiprocessing.spawn(main, args=(num_processes, args), nprocs=num_processes) |
0 commit comments