Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 148 additions & 2 deletions examples/08_gemm_atomics_all_reduce/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,134 @@ def parse_args():
return vars(parser.parse_args())


def run_gemm_all_reduce(
    A,
    B,
    shmem,
    block_m=256,
    block_n=128,
    block_k=64,
    gsize_m=6,
    two_tiles=True,
    num_stages=1,
    num_warps=8,
    waves_per_eu=0,
    mfma_instr_size=16,
    kpack=2,
    gemm_sms=None,
    trace_tiles=False,
):
    """
    Run the GEMM all-reduce operation on input matrices A and B.

    The K dimension is split evenly across ranks: each rank takes a contiguous
    slice of ``K // world_size`` columns of A and the matching rows of B, and
    hands those slices plus the work buffers to ``matmul.apply``.

    Args:
        A: Input matrix A (M x K).
        B: Second operand, either N x K (transposed internally to K x N) or
            already K x N (used as is). A square K x K matrix is ambiguous
            and is used as is, i.e. interpreted as K x N.
        shmem: Iris shmem object; supplies rank, world size, CU count, heap
            bases, barriers and tensor allocation.
        block_m, block_n, block_k: Block sizes for GEMM.
        gsize_m: Grid size M.
        two_tiles: Use two tiles.
        num_stages: Number of stages.
        num_warps: Number of warps.
        waves_per_eu: Waves per execution unit.
        mfma_instr_size: MFMA instruction size.
        kpack: K packing size.
        gemm_sms: Number of SMs for GEMM. Defaults to
            ``min(cu_count // 2, 64)`` when None.
        trace_tiles: Enable tile tracing. When true, a ``Timestamps`` object
            sized to the tile count is created and its begin/end buffers are
            passed to the kernel.

    Returns:
        Tuple of (global_C, local_C), both M x N tensors with A's dtype
        allocated through ``shmem.zeros``. global_C is the all-reduced
        result; local_C is the buffer passed to the kernel as the rank-local
        output (presumably this rank's partial product -- confirm against
        the kernel).

    Raises:
        ValueError: If neither dimension of B equals K, or if N or K is not
            divisible by the world size.
    """
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()
    cu_count = shmem.get_cu_count()

    M, K = A.shape

    # Work out B's orientation from the dimension that matches K. Reading N
    # from B.shape[0] unconditionally mis-sizes the output (and transposes in
    # the wrong direction) when a non-square B arrives already as K x N.
    b_rows, b_cols = B.shape
    if b_rows == K:
        N = b_cols
        needs_transpose = False
    elif b_cols == K:
        N = b_rows
        needs_transpose = True
    else:
        raise ValueError(
            f"B has shape ({b_rows}, {b_cols}); one of its dimensions must equal K ({K})."
        )

    # Validate matrix dimensions. Raised explicitly rather than asserted so
    # the checks still run under `python -O`.
    if N % world_size != 0:
        raise ValueError(f"N ({N}) must be divisible by world size ({world_size}).")
    if K % world_size != 0:
        raise ValueError(f"K ({K}) must be divisible by world size ({world_size}).")

    # Bring B into K x N layout so it can be sliced along K below.
    if needs_transpose:
        B = B.T

    # Set default gemm_sms if not provided
    if gemm_sms is None:
        gemm_sms = min(cu_count // 2, 64)

    # Split matrices according to rank: this rank owns one contiguous K-slice.
    rows_per_gpu = K // world_size
    start_row = rank * rows_per_gpu
    end_row = start_row + rows_per_gpu
    local_B = B[start_row:end_row, :]
    local_A = A[:, start_row:end_row]

    # Create output matrices
    global_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype)
    local_C = shmem.zeros((M, N), device="cuda", dtype=A.dtype)

    # Setup parameters
    total_blocks_M = triton.cdiv(M, block_m)
    total_blocks_N = triton.cdiv(N, block_n)
    total_tiles = total_blocks_M * total_blocks_N

    # Create required tensors: one flag per output tile, plus one lock and one
    # block_m * block_n float32 scratch row per GEMM SM.
    tile_completed = shmem.zeros((total_tiles,), device="cuda", dtype=torch.int32)
    locks = shmem.zeros((gemm_sms,), device="cuda", dtype=torch.int32)
    P = shmem.zeros(
        (gemm_sms, block_m * block_n),
        device="cuda",
        dtype=torch.float32,
    )
    bias = None

    # Setup timestamps if tracing
    timestamps = Timestamps(num_tiles=total_tiles) if trace_tiles else None

    # Synchronize before computation
    shmem.barrier()
    iris.memset_tensor(tile_completed, 0)
    shmem.barrier()

    # Run the GEMM all-reduce operation. The return value of matmul.apply is
    # not needed here; callers receive global_C and local_C.
    matmul.set_debug(False)
    matmul.apply(
        local_A,
        local_B,
        local_C,
        global_C,
        bias,
        P,
        locks,
        tile_completed,
        rank,
        world_size,
        gemm_sms,
        block_m,
        block_n,
        block_k,
        gsize_m,
        two_tiles,
        num_stages,
        num_warps,
        waves_per_eu,
        mfma_instr_size,
        kpack,
        shmem.get_heap_bases(),
        cu_count,
        trace_tiles,
        timestamps.mm_begin_timestamp if timestamps else None,
        timestamps.mm_end_timestamp if timestamps else None,
    )

    # Synchronize after computation
    shmem.barrier()

    return global_C, local_C


def main():
args = parse_args()

Expand Down Expand Up @@ -239,9 +367,27 @@ def run_experiment():
if args["validate"]:
shmem.info("Validating...")

matmul.set_debug(False)
# Use the reusable function for validation
global_C_validate, _ = run_gemm_all_reduce(
A,
B,
shmem,
block_m=args["BLK_M"],
block_n=args["BLK_N"],
block_k=args["BLK_K"],
gsize_m=args["gsize_m"],
two_tiles=args["two_tiles"],
num_stages=args["num_stages"],
num_warps=args["num_warps"],
waves_per_eu=args["waves_per_eu"],
mfma_instr_size=args["mfmaInstrSize"],
kpack=args["kpack"],
gemm_sms=args["gemm_sms"],
trace_tiles=False,
)

# Validate global result
success = validate_gemm(A, B, global_C, shmem, atol=2)
success = validate_gemm(A, B, global_C_validate, shmem, atol=2)
passed_str = "passed" if success else "failed"
shmem.info(f"Final C validation {passed_str}.")

Expand Down
70 changes: 70 additions & 0 deletions tests/examples/test_gemm_atomics_all_reduce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import importlib.util
from pathlib import Path

import pytest
import torch
import iris
from examples.common.validation import validate_gemm

# Load the example's benchmark.py straight from its file path. The example
# directory name starts with a digit, so it cannot be named in a regular
# `import` statement.
_BENCHMARK_RELPATH = "../../examples/08_gemm_atomics_all_reduce/benchmark.py"
current_dir = Path(__file__).parent
benchmark_path = current_dir.joinpath(_BENCHMARK_RELPATH).resolve()
spec = importlib.util.spec_from_file_location("benchmark", benchmark_path)
benchmark_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(benchmark_module)

# Parametrization grids consumed by the test below.
DTYPES = [
    torch.float16,
    torch.float32,
]
MATRIX_SIZES = [
    (256, 256, 256),
    (512, 512, 512),
]  # (m, n, k)
BLOCK_SIZES = [
    (64, 64, 32),
]  # (block_m, block_n, block_k)


# Absolute tolerance handed to validate_gemm when checking the all-reduced
# output; named so it is easy to find and adjust per precision requirement.
GEMM_VALIDATION_ATOL = 1e-1


@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("m, n, k", MATRIX_SIZES)
@pytest.mark.parametrize("block_m, block_n, block_k", BLOCK_SIZES)
def test_gemm_atomics_all_reduce(dtype, m, n, k, block_m, block_n, block_k):
    """Run the benchmark's GEMM all-reduce on random inputs and validate it.

    The case is skipped when n or k is not divisible by the number of ranks,
    since run_gemm_all_reduce rejects such shapes.
    """
    # Initialize iris with appropriate heap size
    heap_size = 1 << 30  # 1GB
    shmem = iris.iris(heap_size)

    world_size = shmem.get_num_ranks()

    # Skip test if matrix dimensions are not divisible by world size
    if n % world_size != 0 or k % world_size != 0:
        pytest.skip(f"Matrix dimensions not divisible by world size {world_size}")

    # Create test matrices
    A = shmem.randn(m, k, device="cuda", dtype=dtype)
    B = shmem.randn(n, k, device="cuda", dtype=dtype)

    # Run the GEMM all-reduce operation using the benchmark function; only the
    # all-reduced output is checked, so the rank-local buffer is discarded.
    global_C, _ = benchmark_module.run_gemm_all_reduce(
        A,
        B,
        shmem,
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        gsize_m=8,
        two_tiles=True,
        num_stages=4,
        num_warps=4,
        waves_per_eu=2,
        mfma_instr_size=16,
        kpack=1,
        trace_tiles=False,
    )

    # Validate results
    success = validate_gemm(A, B, global_C, shmem, atol=GEMM_VALIDATION_ATOL)

    # Assert test passed
    assert success, "GEMM all-reduce validation failed"

    # Verify that we got a non-zero result
    assert not torch.allclose(global_C, torch.zeros_like(global_C)), "Result should not be all zeros"
Loading