
Commit 46fc083
addressing comments
1 parent 5eff862
File tree

2 files changed (+281, -116 lines)
Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import json
import argparse
from pathlib import Path
import torch
import triton
import triton.language as tl
import iris
from iris._mpi_helpers import mpi_allgather
from examples.common.utils import read_realtime


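# Two ranks bounce a token back and forth through the Iris symmetric heap:
# each round, one rank stores the payload plus a flag token to its peer,
# then spins on its own flag until the peer's reply token arrives.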
@triton.jit()
def ping_pong(
    data,
    n_elements,
    skip,
    niter,
    flag,
    curr_rank,
    peer_rank,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    mm_begin_timestamp_ptr: tl.tensor = None,
    mm_end_timestamp_ptr: tl.tensor = None,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    data_mask = offsets < n_elements
    time_stmp_mask = offsets < BLOCK_SIZE
    flag_mask = offsets < 1

    for i in range(niter + skip):
        # Record the begin timestamp once the warmup (skip) rounds are done.
        if i == skip:
            start = read_realtime()
            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
        # Alternate the initiator each round so both directions are exercised.
        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
        token_first_done = i + 1
        token_second_done = i + 2
        if curr_rank == first_rank:
            # Initiator: write payload and token, then spin until the reply token lands.
            iris.store(data + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                pass
        else:
            # Responder: spin until the initiator's token arrives, then reply.
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                pass
            iris.store(data + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)

    stop = read_realtime()
    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
    dtype_map = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
        "int32": torch.int32,
    }
    try:
        return dtype_map[datatype]
    except KeyError:
        raise ValueError(f"Unknown datatype: {datatype}")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Latency ping-pong benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-t",
        "--datatype",
        type=str,
        default="int32",
        choices=["int8", "fp16", "bf16", "fp32", "int32"],
        help="Datatype for the message payload",
    )
    parser.add_argument(
        "-p",
        "--heap_size",
        type=int,
        default=1 << 32,
        help="Iris heap size",
    )
    parser.add_argument(
        "-b",
        "--block_size",
        type=int,
        default=1,
        help="Block size",
    )
    parser.add_argument(
        "-z",
        "--buffer_size",
        type=int,
        default=1,
        help="Length of the source buffer (elements)",
    )
    parser.add_argument(
        "-i",
        "--iter",
        type=int,
        default=100,
        help="Number of timed iterations",
    )
    parser.add_argument(
        "-w",
        "--num_warmup",
        type=int,
        default=10,
        help="Number of warmup (skip) iterations",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default=None,
        help="Optional output filename (if omitted, prints results to terminal). Supports .json, .csv",
    )
    return vars(parser.parse_args())


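# Reporting helpers: print the latency matrix to the terminal or write it as CSV/JSON.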
def _pretty_print_matrix(latency_matrix: torch.Tensor) -> None:
    num_ranks = latency_matrix.shape[0]
    col_width = 12
    header = "SRC\\DST".ljust(col_width) + "".join(f"{j:>12}" for j in range(num_ranks))
    print("\nLatency matrix (ns per iter):")
    print(header)
    for i in range(num_ranks):
        row = f"R{i}".ljust(col_width)
        for j in range(num_ranks):
            row += f"{latency_matrix[i, j].item():12.6f}"
        print(row)


def _write_csv(path: Path, latency_matrix: torch.Tensor) -> None:
    import csv

    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        num_ranks = latency_matrix.shape[0]
        writer.writerow([""] + [f"R{j}" for j in range(num_ranks)])
        for i in range(num_ranks):
            row = [f"R{i}"] + [f"{latency_matrix[i, j].item():0.6f}" for j in range(num_ranks)]
            writer.writerow(row)


def _write_json(path: Path, latency_matrix: torch.Tensor) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    num_ranks = latency_matrix.shape[0]
    rows = []
    for s in range(num_ranks):
        for d in range(num_ranks):
            rows.append(
                {
                    "source_rank": int(s),
                    "destination_rank": int(d),
                    "latency_ns": float(latency_matrix[s, d].item()),
                }
            )
    with path.open("w") as f:
        json.dump(rows, f, indent=2)


def save_results(latency_matrix: torch.Tensor, out: str | None) -> None:
    if out is None:
        _pretty_print_matrix(latency_matrix)
        return

    path = Path(out)
    ext = path.suffix.lower()
    if ext == ".json":
        _write_json(path, latency_matrix)
    elif ext == ".csv":
        _write_csv(path, latency_matrix)
    else:
        raise ValueError(f"Unsupported output file extension: {out}")


def print_run_settings(
    args: dict,
    num_ranks: int,
    dtype: torch.dtype,
    BLOCK_SIZE: int,
    BUFFER_LEN: int,
) -> None:
    elem_size = torch.tensor([], dtype=dtype).element_size()
    heap_size = args["heap_size"]
    out = args["output_file"]
    header = "=" * 72
    print(header)
    print("Latency benchmark -- run settings")
    print(header)
    print(f"  num_ranks     : {num_ranks}")
    print(f"  iterations    : {args['iter']} (timed)")
    print(f"  skip (warmup) : {args['num_warmup']}")
    print(f"  datatype      : {args['datatype']} (torch dtype: {dtype})")
    print(f"  element size  : {elem_size} bytes")
    print(f"  heap size     : {heap_size} ({hex(heap_size)})")
    print(f"  block size    : {BLOCK_SIZE}")
    print(f"  buffer len    : {BUFFER_LEN} elements")
    print(f"  output target : {'<terminal>' if out is None else out}")
    print(header)


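# Host-side driver: set up the Iris heap, run the ping-pong kernel for every
# ordered rank pair, and gather per-rank results into the latency matrix.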
if __name__ == "__main__":
    args = parse_args()
    dtype = torch_dtype_from_str(args["datatype"])
    heap_size = args["heap_size"]

    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    BLOCK_SIZE = args["block_size"]
    BUFFER_LEN = args["buffer_size"]

    niter = args["iter"]
    skip = args["num_warmup"]

    if cur_rank == 0:
        print_run_settings(args, num_ranks, dtype, BLOCK_SIZE, BUFFER_LEN)
    shmem.barrier()
    try:
        device_idx = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_idx)
    except Exception:
        # Keep device_idx defined so the status line below cannot raise NameError.
        device_idx = -1
        device_name = "unknown CUDA device"
    print(f"[rank {cur_rank}] ready, device[{device_idx}]: {device_name}")

    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    flag = shmem.ones(1, dtype=torch.int32)

    grid = lambda meta: (1,)
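    # Sweep every ordered (source, destination) pair. Only the two ranks in a
    # pair launch the kernel; all ranks synchronize after each measurement so
    # pairs never overlap.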
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
                ping_pong[grid](
                    source_buffer,
                    BUFFER_LEN,
                    skip,
                    niter,
                    flag,
                    cur_rank,
                    peer_for_me,
                    BLOCK_SIZE,
                    heap_bases,
                    mm_begin_timestamp,
                    mm_end_timestamp,
                )
            shmem.barrier()

    # Per-peer latency: average the begin/end deltas across the block, then
    # divide by the number of timed iterations.
    mm_begin_cpu = mm_begin_timestamp.cpu().numpy()
    mm_end_cpu = mm_end_timestamp.cpu().numpy()
    for destination_rank in range(num_ranks):
        delta = mm_end_cpu[destination_rank, :] - mm_begin_cpu[destination_rank, :]
        avg_ns = float(delta.sum() / max(1, delta.size) / max(1, niter))
        local_latency[destination_rank] = avg_ns

    # Gather each rank's row into the full num_ranks x num_ranks matrix.
    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
        save_results(latency_matrix, args["output_file"])
        print("Benchmark complete.")

tests/examples/test_load_latency.py

Lines changed: 0 additions & 116 deletions
This file was deleted.
