Skip to content
Merged
67 changes: 32 additions & 35 deletions examples/04_atomic_add/atomic_add_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import torch.multiprocessing as mp
import triton
import triton.language as tl
import sys

import iris
from examples.common.utils import torch_dtype_from_str

torch.manual_seed(123)
random.seed(123)
Expand All @@ -22,7 +24,6 @@
@triton.jit
def atomic_add_kernel(
source_buffer, # tl.tensor: pointer to source data
result_buffer, # tl.tensor: pointer to result data
buffer_size, # int32: total number of elements
source_rank: tl.constexpr,
destination_rank: tl.constexpr,
Expand All @@ -43,20 +44,6 @@ def atomic_add_kernel(
)


def torch_dtype_from_str(datatype: str) -> torch.dtype:
dtype_map = {
"fp16": torch.float16,
"fp32": torch.float32,
"int8": torch.int8,
"bf16": torch.bfloat16,
}
try:
return dtype_map[datatype]
except KeyError:
print(f"Unknown datatype: {datatype}")
exit(1)


def parse_args():
parser = argparse.ArgumentParser(
description="Parse Message Passing configuration.",
Expand All @@ -67,14 +54,13 @@ def parse_args():
"--datatype",
type=str,
default="fp16",
choices=["fp16", "fp32", "int8", "bf16"],
choices=["fp16", "fp32", "bf16", "int32", "int64"],
help="Datatype of computation",
)
parser.add_argument("-z", "--buffer_size", type=int, default=1 << 32, help="Buffer Size")
parser.add_argument("-b", "--block_size", type=int, default=512, help="Block Size")
parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
parser.add_argument("-d", "--validate", action="store_true", help="Enable validation output")

parser.add_argument("-p", "--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument("-o", "--output_file", type=str, default="", help="Output file")

Expand All @@ -85,7 +71,7 @@ def parse_args():
return vars(parser.parse_args())


def run_experiment(shmem, args, source_rank, destination_rank, source_buffer, result_buffer):
def run_experiment(shmem, args, source_rank, destination_rank, source_buffer):
dtype = torch_dtype_from_str(args["datatype"])
cur_rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
Expand All @@ -108,19 +94,25 @@ def run_atomic_add():
if cur_rank == source_rank:
atomic_add_kernel[grid](
source_buffer,
result_buffer,
n_elements,
source_rank,
destination_rank,
args["block_size"],
shmem.get_heap_bases(),
)

def preamble():
source_buffer.fill_(0)

# Warmup
run_atomic_add()
shmem.barrier()
atomic_add_ms = iris.do_bench(
run_atomic_add, shmem.barrier, n_repeat=args["num_experiments"], n_warmup=args["num_warmup"]
run_atomic_add,
barrier_fn=shmem.barrier,
preamble_fn=preamble,
n_repeat=args["num_experiments"],
n_warmup=args["num_warmup"],
)

# Subtract overhead
Expand All @@ -143,28 +135,34 @@ def run_atomic_add():
if args["verbose"]:
shmem.info("Validating output...")

expected = torch.arange(n_elements, dtype=dtype, device="cuda")
diff_mask = ~torch.isclose(result_buffer, expected, atol=1)
breaking_indices = torch.nonzero(diff_mask, as_tuple=False)
expected = torch.ones(n_elements, dtype=dtype, device="cuda")

diff_mask = ~torch.isclose(source_buffer, expected)

if not torch.allclose(result_buffer, expected, atol=1):
max_diff = (result_buffer - expected).abs().max().item()
if torch.any(diff_mask):
max_diff = (source_buffer - expected).abs().max().item()
shmem.info(f"Max absolute difference: {max_diff}")
for idx in breaking_indices:
idx = tuple(idx.tolist())
computed_val = result_buffer[idx]
expected_val = expected[idx]
shmem.error(f"Mismatch at index {idx}: C={computed_val}, expected={expected_val}")
success = False
break

first_mismatch_idx = torch.argmax(diff_mask.float()).item()
computed_val = source_buffer[first_mismatch_idx]
expected_val = expected[first_mismatch_idx]
shmem.error(f"First mismatch at index {first_mismatch_idx}: C={computed_val}, expected={expected_val}")
success = False

if success and args["verbose"]:
shmem.info("Validation successful.")
if not success and args["verbose"]:
shmem.error("Validation failed.")

success = shmem.broadcast(success, source_rank)

shmem.barrier()
return bandwidth_gbps

if not success:
dist.destroy_process_group()
sys.exit(1)

return bandwidth_gbps, source_buffer.clone()


def print_bandwidth_matrix(
Expand Down Expand Up @@ -218,11 +216,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
dtype = torch_dtype_from_str(args["datatype"])
element_size_bytes = torch.tensor([], dtype=dtype).element_size()
source_buffer = shmem.arange(args["buffer_size"] // element_size_bytes, device="cuda", dtype=dtype)
result_buffer = shmem.zeros_like(source_buffer)

for source_rank in range(num_ranks):
for destination_rank in range(num_ranks):
bandwidth_gbps = run_experiment(shmem, args, source_rank, destination_rank, source_buffer, result_buffer)
bandwidth_gbps, _ = run_experiment(shmem, args, source_rank, destination_rank, source_buffer)
bandwidth_matrix[source_rank, destination_rank] = bandwidth_gbps
shmem.barrier()

Expand Down
22 changes: 22 additions & 0 deletions examples/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,28 @@
ALL_GATHER = tl.constexpr(6)


dtype_map = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
"int8": torch.int8,
"int32": torch.int32,
"int64": torch.int64,
}


def torch_dtype_from_str(datatype: str) -> torch.dtype:
try:
return dtype_map[datatype]
except KeyError:
print(f"Unknown datatype: {datatype}")
exit(1)


def torch_dtype_to_str(dtype: torch.dtype) -> str:
return list(dtype_map.keys())[list(dtype_map.values()).index(dtype)]


class JSONWriter:
def __init__(self, file_path):
self.file_path = file_path
Expand Down
126 changes: 126 additions & 0 deletions tests/examples/test_atomic_add_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import pytest
import torch
import triton
import triton.language as tl
import numpy as np
import iris

import importlib.util
from pathlib import Path
from examples.common.utils import torch_dtype_to_str

current_dir = Path(__file__).parent
file_path = (current_dir / "../../examples/04_atomic_add/atomic_add_bench.py").resolve()
module_name = "atomic_add_bench"
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)


@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"buffer_size, heap_size",
[
(20480, (1 << 33)),
],
)
@pytest.mark.parametrize(
"block_size",
[
512,
1024,
],
)
def test_atomic_bandwidth(dtype, buffer_size, heap_size, block_size):
"""Test that atomic_add benchmark runs and produces positive bandwidth."""
shmem = iris.iris(heap_size)
num_ranks = shmem.get_num_ranks()

element_size_bytes = torch.tensor([], dtype=dtype).element_size()
n_elements = buffer_size // element_size_bytes
source_buffer = shmem.arange(n_elements, dtype=dtype)

shmem.barrier()

args = {
"datatype": torch_dtype_to_str(dtype),
"block_size": block_size,
"verbose": False,
"validate": False,
"num_experiments": 10,
"num_warmup": 5,
}

source_rank = 0
destination_rank = 1 if num_ranks > 1 else 0

bandwidth_gbps, _ = module.run_experiment(shmem, args, source_rank, destination_rank, source_buffer)

assert bandwidth_gbps > 0, f"Bandwidth should be positive, got {bandwidth_gbps}"

shmem.barrier()


@pytest.mark.parametrize(
"dtype",
[
torch.float16,
torch.bfloat16,
torch.float32,
],
)
@pytest.mark.parametrize(
"buffer_size, heap_size",
[
(20480, (1 << 33)),
],
)
@pytest.mark.parametrize(
"block_size",
[
512,
1024,
],
)
def test_atomic_correctness(dtype, buffer_size, heap_size, block_size):
"""Test that atomic_add benchmark runs and produces positive bandwidth."""
shmem = iris.iris(heap_size)
num_ranks = shmem.get_num_ranks()

element_size_bytes = torch.tensor([], dtype=dtype).element_size()
n_elements = buffer_size // element_size_bytes
source_buffer = shmem.arange(n_elements, dtype=dtype)

shmem.barrier()

args = {
"datatype": torch_dtype_to_str(dtype),
"block_size": block_size,
"verbose": False,
"validate": False,
"num_experiments": 1,
"num_warmup": 0,
}

source_rank = 0
destination_rank = 1 if num_ranks > 1 else 0

_, result_buffer = module.run_experiment(shmem, args, source_rank, destination_rank, source_buffer)

if shmem.get_rank() == destination_rank:
expected = torch.ones(n_elements, dtype=dtype, device="cuda")

assert torch.allclose(result_buffer, expected), "Result buffer should be equal to expected"

shmem.barrier()
Loading