#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

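"""Latency ping-pong benchmark.

Measures round-trip latency between every ordered pair of ranks with a Triton
ping-pong kernel over the Iris symmetric heap, then reports the all-pairs
latency matrix (pretty-printed, or written to .json/.csv).
"""
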
import json
import csv
import argparse
from pathlib import Path
import torch
import triton
import triton.language as tl
import iris
from iris._mpi_helpers import mpi_allgather
from examples.common.utils import read_realtime


@triton.jit()
def ping_pong(
    data,
    n_elements,
    skip,
    niter,
    flag,
    curr_rank,
    peer_rank,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    mm_begin_timestamp_ptr: tl.tensor = None,
    mm_end_timestamp_ptr: tl.tensor = None,
):
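    # Ping-pong protocol: the "first" rank writes its payload and a token into
    # the peer's heap, then spin-waits on its own flag for the peer's reply
    # token; the "second" rank mirrors this. One iteration is one round trip,
    # and the initiating rank alternates every iteration.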
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    data_mask = offsets < n_elements
    timestamp_mask = offsets < BLOCK_SIZE
    flag_mask = offsets < 1

    for i in range(niter + skip):
        if i == skip:
            # Start timing only once the warmup iterations are done.
            start = read_realtime()
            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, timestamp_mask)
        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
        token_first_done = i + 1
        token_second_done = i + 2
        if curr_rank == first_rank:
            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
            # Spin on the local flag, bypassing caches so the peer's store is observed.
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                pass
        else:
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                pass
            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)

    stop = read_realtime()
    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, timestamp_mask)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
    dtype_map = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
        "int32": torch.int32,
    }
    try:
        return dtype_map[datatype]
    except KeyError as exc:
        raise ValueError(f"Unknown datatype: {datatype}") from exc


def parse_args():
    parser = argparse.ArgumentParser(
        description="Latency ping-pong benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-t",
        "--datatype",
        type=str,
        default="int32",
        choices=["int8", "fp16", "bf16", "fp32", "int32"],
        help="Datatype for the message payload",
    )
    parser.add_argument(
        "-p",
        "--heap_size",
        type=int,
        default=1 << 32,
        help="Iris heap size in bytes",
    )
    parser.add_argument(
        "-b",
        "--block_size",
        type=int,
        default=1,
        help="Block size",
    )
    parser.add_argument(
        "-z",
        "--buffer_size",
        type=int,
        default=1,
        help="Length of the source buffer (elements)",
    )
    parser.add_argument(
        "-i",
        "--iter",
        type=int,
        default=100,
        help="Number of timed iterations",
    )
    parser.add_argument(
        "-w",
        "--num_warmup",
        type=int,
        default=10,
        help="Number of warmup (skip) iterations",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default=None,
        help="Optional output filename (if omitted, prints results to terminal). Supports .json, .csv",
    )
    return vars(parser.parse_args())


def _pretty_print_matrix(latency_matrix: torch.Tensor) -> None:
    num_ranks = latency_matrix.shape[0]
    col_width = 12
    header = "SRC\\DST".ljust(col_width) + "".join(f"{j:>{col_width}}" for j in range(num_ranks))
    print("\nLatency matrix (ns per iter):")
    print(header)
    for i in range(num_ranks):
        row = f"R{i}".ljust(col_width)
        for j in range(num_ranks):
            row += f"{latency_matrix[i, j].item():{col_width}.6f}"
        print(row)


def _write_csv(path: Path, latency_matrix: torch.Tensor) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        num_ranks = latency_matrix.shape[0]
        writer.writerow([""] + [f"R{j}" for j in range(num_ranks)])
        for i in range(num_ranks):
            row = [f"R{i}"] + [f"{latency_matrix[i, j].item():0.6f}" for j in range(num_ranks)]
            writer.writerow(row)


def _write_json(path: Path, latency_matrix: torch.Tensor) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    num_ranks = latency_matrix.shape[0]
    rows = []
    for s in range(num_ranks):
        for d in range(num_ranks):
            rows.append(
                {
                    "source_rank": int(s),
                    "destination_rank": int(d),
                    "latency_ns": float(latency_matrix[s, d].item()),
                }
            )
    with path.open("w") as f:
        json.dump(rows, f, indent=2)


def save_results(latency_matrix: torch.Tensor, out: str | None) -> None:
    if out is None:
        _pretty_print_matrix(latency_matrix)
        return

    path = Path(out)
    ext = path.suffix.lower()
    if ext == ".json":
        _write_json(path, latency_matrix)
    elif ext == ".csv":
        _write_csv(path, latency_matrix)
    else:
        raise ValueError(f"Unsupported output file extension: {out}")


def print_run_settings(
    args: dict,
    num_ranks: int,
    dtype: torch.dtype,
    BLOCK_SIZE: int,
    BUFFER_LEN: int,
) -> None:
    elem_size = torch.tensor([], dtype=dtype).element_size()
    heap_size = args["heap_size"]
    out = args["output_file"]
    header = "=" * 72
    print(header)
    print("Latency benchmark -- run settings")
    print(header)
    print(f"  num_ranks     : {num_ranks}")
    print(f"  iterations    : {args['iter']} (timed)")
    print(f"  skip (warmup) : {args['num_warmup']}")
    print(f"  datatype      : {args['datatype']} (torch dtype: {dtype})")
    print(f"  element size  : {elem_size} bytes")
    print(f"  heap size     : {heap_size} ({hex(heap_size)})")
    print(f"  block size    : {BLOCK_SIZE}")
    print(f"  buffer len    : {BUFFER_LEN} elements")
    print(f"  output target : {'<terminal>' if out is None else out}")
    print(header)

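# Example launch (illustrative; the script filename and launcher depend on
# your environment, e.g. an MPI wrapper such as mpirun):
#   mpirun -np 2 python ping_pong_latency.py --datatype int32 --iter 1000 -o latency.csv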
if __name__ == "__main__":
    args = parse_args()
    dtype = torch_dtype_from_str(args["datatype"])
    heap_size = args["heap_size"]

    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    BLOCK_SIZE = args["block_size"]
    BUFFER_LEN = args["buffer_size"]

    niter = args["iter"]
    skip = args["num_warmup"]

    if cur_rank == 0:
        print_run_settings(args, num_ranks, dtype, BLOCK_SIZE, BUFFER_LEN)
    shmem.barrier()
    try:
        device_idx = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_idx)
    except Exception:
        # Fall back gracefully so the status line below never raises.
        device_idx = -1
        device_name = "unknown CUDA device"
    print(f"[rank {cur_rank}] ready, device[{device_idx}]: {device_name}")

    # Per-peer timestamps written by the kernel, and this rank's latency row.
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

    local_latency = torch.zeros(num_ranks, dtype=torch.float32, device="cuda")

    # Symmetric-heap buffers visible to all ranks.
    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    flag = shmem.ones(1, dtype=torch.int32)

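    # Sweep every ordered (source, destination) pair. Only the two ranks in
    # the current pair launch the kernel; the barrier after each pair keeps
    # measurements from overlapping.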
    grid = lambda meta: (1,)
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in (source_rank, destination_rank):
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
                ping_pong[grid](
                    source_buffer,
                    BUFFER_LEN,
                    skip,
                    niter,
                    flag,
                    cur_rank,
                    peer_for_me,
                    BLOCK_SIZE,
                    heap_bases,
                    mm_begin_timestamp,
                    mm_end_timestamp,
                )
            shmem.barrier()

    # Average the per-lane timestamp deltas and normalize by the number of
    # timed iterations to get nanoseconds per round trip.
    mm_begin_cpu = mm_begin_timestamp.cpu().numpy()
    mm_end_cpu = mm_end_timestamp.cpu().numpy()
    for destination_rank in range(num_ranks):
        delta = mm_end_cpu[destination_rank, :] - mm_begin_cpu[destination_rank, :]
        avg_ns = float(delta.sum() / max(1, delta.size) / max(1, niter))
        local_latency[destination_rank] = avg_ns

    # Gather each rank's row to form the full source x destination matrix.
    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
        save_results(latency_matrix, args["output_file"])
        print("Benchmark complete.")