|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | +from typing import Any |
| 5 | + |
| 6 | +import torch |
| 7 | +import torch.distributed as dist |
| 8 | +import torch.distributed._symmetric_memory as symm_mem |
| 9 | + |
| 10 | +import helion |
| 11 | +import helion.language as hl |
| 12 | + |
| 13 | + |
| 14 | +def copy_engine_all_gather_w_progress( |
| 15 | + output: torch.Tensor, |
| 16 | + inp: torch.Tensor, # Must be symmetric tensor |
| 17 | + progress: torch.Tensor, |
| 18 | + splits_per_rank: int, |
| 19 | + backend_stream: torch.cuda.Stream | None = None, |
| 20 | +) -> torch.cuda.Stream: |
| 21 | + backend_stream = symm_mem._get_backend_stream(priority=-1) |
| 22 | + assert inp.is_contiguous() |
| 23 | + symm_mem_group = dist.group.WORLD |
| 24 | + if symm_mem_group is None: |
| 25 | + raise RuntimeError("No symmetric memory group available") |
| 26 | + symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group) |
| 27 | + assert symm_mem_hdl is not None |
| 28 | + |
| 29 | + rank = symm_mem_hdl.rank |
| 30 | + world_size = symm_mem_hdl.world_size |
| 31 | + |
| 32 | + assert inp.numel() % splits_per_rank == 0 |
| 33 | + assert progress.numel() >= world_size * splits_per_rank |
| 34 | + |
| 35 | + output_shape = list(inp.shape) |
| 36 | + output_shape[0] *= world_size |
| 37 | + assert list(output.shape) == output_shape, (list(output.shape), output_shape) |
| 38 | + |
| 39 | + chunks = output.chunk(world_size * splits_per_rank) |
| 40 | + |
| 41 | + symm_mem_hdl.barrier() |
| 42 | + backend_stream.wait_stream(torch.cuda.current_stream()) |
| 43 | + |
| 44 | + with torch.cuda.stream(backend_stream): |
| 45 | + for step in range(world_size): |
| 46 | + src_rank = (rank + step + 1) % world_size |
| 47 | + for split_id in range(splits_per_rank): |
| 48 | + src_buf = symm_mem_hdl.get_buffer( |
| 49 | + src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id |
| 50 | + ) |
| 51 | + chunks[src_rank * splits_per_rank + split_id].copy_(src_buf) |
| 52 | + # cuStreamWriteValue32 issues a system level fence before the write |
| 53 | + symm_mem_hdl.stream_write_value32( |
| 54 | + progress, |
| 55 | + offset=src_rank * splits_per_rank + split_id, |
| 56 | + val=1, |
| 57 | + ) |
| 58 | + symm_mem_hdl.barrier() |
| 59 | + |
| 60 | + return backend_stream |
| 61 | + |
| 62 | + |
| 63 | +@helion.jit( |
| 64 | + config=helion.Config( |
| 65 | + block_sizes=[128, 256, 64], |
| 66 | + num_warps=8, |
| 67 | + num_stages=3, |
| 68 | + indexing="block_ptr", |
| 69 | + ), |
| 70 | + # Static shapes provides a speedup for attention |
| 71 | + static_shapes=True, |
| 72 | +) |
| 73 | +def helion_matmul_w_progress( |
| 74 | + a: torch.Tensor, |
| 75 | + a_shared: torch.Tensor, |
| 76 | + b: torch.Tensor, |
| 77 | + progress: torch.Tensor, |
| 78 | + SPLITS_PER_RANK: int, |
| 79 | + RANK: int, |
| 80 | +) -> torch.Tensor: |
| 81 | + M, K = a.size() |
| 82 | + K2, N = b.size() |
| 83 | + assert K2 == K, f"size mismatch {K2} != {K}" |
| 84 | + |
| 85 | + out = torch.empty( |
| 86 | + [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device |
| 87 | + ) |
| 88 | + |
| 89 | + M_per_rank = a_shared.size(0) |
| 90 | + |
| 91 | + for tile_m, tile_n in hl.tile([M, N]): |
| 92 | + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) |
| 93 | + hl.wait( |
| 94 | + progress, |
| 95 | + [ |
| 96 | + tile_m.begin // (M_per_rank // SPLITS_PER_RANK), |
| 97 | + ], |
| 98 | + signal=1, |
| 99 | + update=None, |
| 100 | + op="ld", |
| 101 | + scope="gpu", |
| 102 | + sem="acquire", |
| 103 | + ) |
| 104 | + for tile_k in hl.tile(K): |
| 105 | + # TODO(joydddd): use a_shared and skipp barrier when data is available on local rank. |
| 106 | + # if tile_k.begin // M_per_rank == RANK: |
| 107 | + # acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n]) |
| 108 | + # else: |
| 109 | + # hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire") |
| 110 | + acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n]) |
| 111 | + out[tile_m, tile_n] = acc |
| 112 | + return out |
| 113 | + |
| 114 | + |
| 115 | +def helion_all_gather_matmul( |
| 116 | + a_shared: torch.Tensor, |
| 117 | + b: torch.Tensor, |
| 118 | + a_out: torch.Tensor | None = None, |
| 119 | + progress: torch.Tensor | None = None, |
| 120 | + **kwargs: Any, |
| 121 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 122 | + configs = { |
| 123 | + "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1), |
| 124 | + "BLOCK_SIZE_M": kwargs.get("block_size_m", 128), |
| 125 | + "BLOCK_SIZE_N": kwargs.get("block_size_n", 256), |
| 126 | + "BLOCK_SIZE_K": kwargs.get("block_size_k", 64), |
| 127 | + "GROUP_SIZE_M": kwargs.get("group_size_m", 4), |
| 128 | + "num_stages": kwargs.get("num_stages", 3), |
| 129 | + "num_warps": kwargs.get("num_warps", 8), |
| 130 | + } |
| 131 | + |
| 132 | + symm_mem_group = dist.group.WORLD |
| 133 | + if symm_mem_group is None: |
| 134 | + raise RuntimeError("No symmetric memory group available") |
| 135 | + |
| 136 | + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group) |
| 137 | + |
| 138 | + a_shape = list(a_shared.shape) |
| 139 | + a_shape[0] *= symm_mem_hdl.world_size |
| 140 | + |
| 141 | + configs["RANK"] = symm_mem_hdl.rank |
| 142 | + configs["WORLD_SIZE"] = symm_mem_hdl.world_size |
| 143 | + if ( |
| 144 | + configs["SPLITS_PER_RANK"] |
| 145 | + * configs["WORLD_SIZE"] |
| 146 | + * configs["BLOCK_SIZE_M"] |
| 147 | + * configs["GROUP_SIZE_M"] |
| 148 | + > a_shape[0] |
| 149 | + ): |
| 150 | + configs["GROUP_SIZE_M"] = 1 |
| 151 | + configs["SPLITS_PER_RANK"] = 1 |
| 152 | + |
| 153 | + configs["COMM_BLOCK_SIZE_M"] = ( |
| 154 | + a_shape[0] // configs["WORLD_SIZE"] // configs["SPLITS_PER_RANK"] |
| 155 | + ) |
| 156 | + assert ( |
| 157 | + configs["COMM_BLOCK_SIZE_M"] |
| 158 | + % (configs["BLOCK_SIZE_M"] * configs["GROUP_SIZE_M"]) |
| 159 | + == 0 |
| 160 | + ) |
| 161 | + |
| 162 | + if a_out is None: |
| 163 | + a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device) |
| 164 | + |
| 165 | + if progress is None: |
| 166 | + progress = torch.zeros( |
| 167 | + symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"], |
| 168 | + dtype=torch.uint32, |
| 169 | + device=a_shared.device, |
| 170 | + ) |
| 171 | + else: |
| 172 | + progress.fill_( |
| 173 | + 0 |
| 174 | + ) # Reset progress to 0. Maybe we should reset inside the kernel using cas? |
| 175 | + |
| 176 | + backend_stream = copy_engine_all_gather_w_progress( |
| 177 | + a_out, a_shared, progress, configs["SPLITS_PER_RANK"] |
| 178 | + ) |
| 179 | + |
| 180 | + c = helion_matmul_w_progress( |
| 181 | + a_out, |
| 182 | + a_shared, |
| 183 | + b, |
| 184 | + progress, |
| 185 | + SPLITS_PER_RANK=configs["SPLITS_PER_RANK"], |
| 186 | + RANK=configs["RANK"], |
| 187 | + ) |
| 188 | + assert type(c) is torch.Tensor |
| 189 | + |
| 190 | + torch.cuda.current_stream().wait_stream(backend_stream) |
| 191 | + |
| 192 | + return a_out, c |
| 193 | + |
| 194 | + |
| 195 | +def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None: |
| 196 | + a_shared = symm_mem.empty( |
| 197 | + M // world_size, K, dtype=torch.bfloat16, device=device |
| 198 | + ).normal_() |
| 199 | + b = torch.randn((K, N), device="cuda", dtype=torch.bfloat16).T.contiguous().T |
| 200 | + |
| 201 | + a_out, c = helion_all_gather_matmul(a_shared, b) |
| 202 | + |
| 203 | + golden_a = a_shared.clone() |
| 204 | + dist_group = dist.group.WORLD |
| 205 | + if dist_group is None: |
| 206 | + raise RuntimeError("No distributed group available") |
| 207 | + ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul( |
| 208 | + golden_a, [b], gather_dim=0, group_name=dist_group.group_name |
| 209 | + ) |
| 210 | + torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1) |
| 211 | + torch.testing.assert_close(a_out, ag_golden) |
| 212 | + |
| 213 | + |
| 214 | +def main() -> None: |
| 215 | + rank = int(os.environ["LOCAL_RANK"]) |
| 216 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 217 | + torch.manual_seed(42 + rank) |
| 218 | + device = torch.device(f"cuda:{rank}") |
| 219 | + torch.cuda.set_device(device) |
| 220 | + dist.init_process_group("nccl") |
| 221 | + test(4096, 6656, 16384, world_size, device) |
| 222 | + |
| 223 | + dist.destroy_process_group() |
| 224 | + |
| 225 | + |
| 226 | +if __name__ == "__main__": |
| 227 | + """ |
| 228 | + torchrun \ |
| 229 | + --nnodes 1 --nproc-per-node 8 \ |
| 230 | + --rdzv-backend c10d --rdzv-endpoint localhost:0 \ |
| 231 | + --no_python python3 examples/all_gather_matmul.py |
| 232 | + """ |
| 233 | + main() |
0 commit comments