
Commit 606853e: "addressing comments" (1 parent: 5eff862)
File tree: 2 files changed, +281 -116 lines

Lines changed: 281 additions & 0 deletions (new file)

@@ -0,0 +1,281 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import json
import csv
import argparse
from pathlib import Path

import torch
import triton
import triton.language as tl

import iris
from iris._mpi_helpers import mpi_allgather
from examples.common.utils import read_realtime


@triton.jit()
def ping_pong(
    data,
    n_elements,
    skip,
    niter,
    flag,
    curr_rank,
    peer_rank,
    BLOCK_SIZE: tl.constexpr,
    heap_bases: tl.tensor,
    mm_begin_timestamp_ptr: tl.tensor = None,
    mm_end_timestamp_ptr: tl.tensor = None,
):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    data_mask = offsets < n_elements
    time_stmp_mask = offsets < BLOCK_SIZE
    flag_mask = offsets < 1
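    # Ping-pong: the two ranks alternate who leads each iteration. The lead
    # rank writes the payload plus a token into the peer's heap, then spins
    # on its own flag until the peer's reply token arrives; the other rank
    # performs the same steps in mirror order.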
    for i in range(niter + skip):
        if i == skip:
            # Warmup complete: record the start timestamp exactly once.
            start = read_realtime()
            tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
        first_rank = tl.minimum(curr_rank, peer_rank) if (i % 2) == 0 else tl.maximum(curr_rank, peer_rank)
        token_first_done = i + 1
        token_second_done = i + 2
        if curr_rank == first_rank:
            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_first_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_second_done:
                pass
        else:
            while tl.load(flag, cache_modifier=".cv", volatile=True) != token_first_done:
                pass
            iris.store(data + offsets, i, curr_rank, peer_rank, heap_bases, mask=data_mask)
            iris.store(flag + offsets, token_second_done, curr_rank, peer_rank, heap_bases, mask=flag_mask)

    stop = read_realtime()
    tl.store(mm_end_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, stop, time_stmp_mask)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
    dtype_map = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
        "int32": torch.int32,
    }
    try:
        return dtype_map[datatype]
    except KeyError:
        raise ValueError(f"Unknown datatype: {datatype}")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Latency ping-pong benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-t",
        "--datatype",
        type=str,
        default="int32",
        choices=["int8", "fp16", "bf16", "fp32", "int32"],
        help="Datatype for the message payload",
    )
    parser.add_argument(
        "-p",
        "--heap_size",
        type=int,
        default=1 << 32,
        help="Iris heap size",
    )
    parser.add_argument(
        "-b",
        "--block_size",
        type=int,
        default=1,
        help="Block size",
    )
    parser.add_argument(
        "-z",
        "--buffer_size",
        type=int,
        default=1,
        help="Length of the source buffer (elements)",
    )
    parser.add_argument(
        "-i",
        "--iter",
        type=int,
        default=100,
        help="Number of timed iterations",
    )
    parser.add_argument(
        "-w",
        "--num_warmup",
        type=int,
        default=10,
        help="Number of warmup (skip) iterations",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default=None,
        help="Optional output filename (if omitted, prints results to terminal). Supports .json, .csv",
    )
    return vars(parser.parse_args())


def _pretty_print_matrix(latency_matrix: torch.Tensor) -> None:
    num_ranks = latency_matrix.shape[0]
    col_width = 12
    header = "SRC\\DST".ljust(col_width) + "".join(f"{j:>12}" for j in range(num_ranks))
    print("\nLatency matrix (ns per iter):")
    print(header)
    for i in range(num_ranks):
        row = f"R{i}".ljust(col_width)
        for j in range(num_ranks):
            row += f"{latency_matrix[i, j].item():12.6f}"
        print(row)


def _write_csv(path: Path, latency_matrix: torch.Tensor) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="") as f:
        writer = csv.writer(f)
        num_ranks = latency_matrix.shape[0]
        writer.writerow([""] + [f"R{j}" for j in range(num_ranks)])
        for i in range(num_ranks):
            row = [f"R{i}"] + [f"{latency_matrix[i, j].item():0.6f}" for j in range(num_ranks)]
            writer.writerow(row)


def _write_json(path: Path, latency_matrix: torch.Tensor) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    num_ranks = latency_matrix.shape[0]
    rows = []
    for s in range(num_ranks):
        for d in range(num_ranks):
            rows.append(
                {
                    "source_rank": int(s),
                    "destination_rank": int(d),
                    "latency_ns": float(latency_matrix[s, d].item()),
                }
            )
    with path.open("w") as f:
        json.dump(rows, f, indent=2)


def save_results(latency_matrix: torch.Tensor, out: str | None) -> None:
    if out is None:
        _pretty_print_matrix(latency_matrix)
        return

    path = Path(out)
    ext = path.suffix.lower()
    if ext == ".json":
        _write_json(path, latency_matrix)
    elif ext == ".csv":
        _write_csv(path, latency_matrix)
    else:
        raise ValueError(f"Unsupported output file extension: {out}")


def print_run_settings(
    args: dict,
    num_ranks: int,
    dtype: torch.dtype,
    BLOCK_SIZE: int,
    BUFFER_LEN: int,
) -> None:
    elem_size = torch.tensor([], dtype=dtype).element_size()
    heap_size = args["heap_size"]
    out = args["output_file"]
    header = "=" * 72
    print(header)
    print("Latency benchmark -- run settings")
    print(header)
    print(f"  num_ranks     : {num_ranks}")
    print(f"  iterations    : {args['iter']} (timed)")
    print(f"  skip (warmup) : {args['num_warmup']}")
    print(f"  datatype      : {args['datatype']} (torch dtype: {dtype})")
    print(f"  element size  : {elem_size} bytes")
    print(f"  heap size     : {heap_size} ({hex(heap_size)})")
    print(f"  block size    : {BLOCK_SIZE}")
    print(f"  buffer len    : {BUFFER_LEN} elements")
    print(f"  output target : {'<terminal>' if out is None else out}")
    print(header)


if __name__ == "__main__":
    args = parse_args()
    dtype = torch_dtype_from_str(args["datatype"])
    heap_size = args["heap_size"]

    shmem = iris.iris(heap_size)
    num_ranks = shmem.get_num_ranks()
    heap_bases = shmem.get_heap_bases()
    cur_rank = shmem.get_rank()

    BLOCK_SIZE = args["block_size"]
    BUFFER_LEN = args["buffer_size"]

    niter = args["iter"]
    skip = args["num_warmup"]

    if cur_rank == 0:
        print_run_settings(args, num_ranks, dtype, BLOCK_SIZE, BUFFER_LEN)
    shmem.barrier()
    try:
        device_idx = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_idx)
    except Exception:
        # device_idx would otherwise be unbound if current_device() raised.
        device_idx = -1
        device_name = "unknown CUDA device"
    print(f"[rank {cur_rank}] ready, device[{device_idx}]: {device_name}")

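    # One row of BLOCK_SIZE timestamps per peer: the kernel writes row
    # `peer_rank` of these buffers on whichever ranks join a given pairing.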
    mm_begin_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")
    mm_end_timestamp = torch.zeros((num_ranks, BLOCK_SIZE), dtype=torch.int64, device="cuda")

    local_latency = torch.zeros((num_ranks), dtype=torch.float32, device="cuda")

    source_buffer = shmem.ones(BUFFER_LEN, dtype=dtype)
    flag = shmem.ones(1, dtype=torch.int32)

    grid = lambda meta: (1,)
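    # Sweep all ordered (source, destination) pairs one at a time; only the
    # two ranks in the current pair launch the kernel, and every rank joins
    # the barrier so pairs never overlap.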
    for source_rank in range(num_ranks):
        for destination_rank in range(num_ranks):
            if source_rank != destination_rank and cur_rank in [source_rank, destination_rank]:
                peer_for_me = destination_rank if cur_rank == source_rank else source_rank
                ping_pong[grid](
                    source_buffer,
                    BUFFER_LEN,
                    skip,
                    niter,
                    flag,
                    cur_rank,
                    peer_for_me,
                    BLOCK_SIZE,
                    heap_bases,
                    mm_begin_timestamp,
                    mm_end_timestamp,
                )
            shmem.barrier()

    # Per-peer latency: average the per-lane timestamp deltas, then divide
    # by the number of timed iterations to get ns per ping-pong round.
    mm_begin_cpu = mm_begin_timestamp.cpu().numpy()
    mm_end_cpu = mm_end_timestamp.cpu().numpy()
    for destination_rank in range(num_ranks):
        delta = mm_end_cpu[destination_rank, :] - mm_begin_cpu[destination_rank, :]
        avg_ns = float(delta.sum() / max(1, delta.size) / max(1, niter))
        local_latency[destination_rank] = avg_ns

    # Each rank contributes one row; allgather assembles the full matrix.
    latency_matrix = mpi_allgather(local_latency.cpu())

    if cur_rank == 0:
        save_results(latency_matrix, args["output_file"])
        print("Benchmark complete.")

tests/examples/test_load_latency.py

Lines changed: 0 additions & 116 deletions
This file was deleted.
