-
Notifications
You must be signed in to change notification settings - Fork 16
Fix atomic_add validation by correcting expected values and enhancing test coverage #177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
mawad-amd
merged 12 commits into
main
from
copilot/fix-cd6f01c8-b9a7-4dd9-a078-222e0df60796
Sep 29, 2025
+180
−35
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
e99a158
Initial plan
Copilot c34d98c
Fix atomic_add validation by storing results and resetting buffers
Copilot 27e2ca1
Remove tl.store, validate source_buffer with num_ranks expectation, a…
Copilot a82619b
Update expected validation values to use torch.ones * world_size
Copilot 81acaea
Fix failing tests by simplifying validation and test expectations
Copilot c3209f0
Test perf and correctness
mawad-amd fad50c4
Remove unneeded variables
mawad-amd 6b14a64
Fix validation
mawad-amd faed964
Remove unneded tolerance
mawad-amd 6ad6caf
Update data types
mawad-amd 2f9e1d8
Use common datatype map
mawad-amd 16f3bcf
Optimize validation by reusing diff_mask instead of recomputing allclose
Copilot File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
#!/usr/bin/env python3 | ||
# SPDX-License-Identifier: MIT | ||
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. | ||
|
||
import pytest | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
import numpy as np | ||
import iris | ||
|
||
import importlib.util | ||
from pathlib import Path | ||
from examples.common.utils import torch_dtype_to_str | ||
|
||
current_dir = Path(__file__).parent | ||
file_path = (current_dir / "../../examples/04_atomic_add/atomic_add_bench.py").resolve() | ||
module_name = "atomic_add_bench" | ||
spec = importlib.util.spec_from_file_location(module_name, file_path) | ||
module = importlib.util.module_from_spec(spec) | ||
spec.loader.exec_module(module) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"dtype", | ||
[ | ||
torch.float16, | ||
torch.bfloat16, | ||
torch.float32, | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"buffer_size, heap_size", | ||
[ | ||
(20480, (1 << 33)), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"block_size", | ||
[ | ||
512, | ||
1024, | ||
], | ||
) | ||
def test_atomic_bandwidth(dtype, buffer_size, heap_size, block_size): | ||
"""Test that atomic_add benchmark runs and produces positive bandwidth.""" | ||
shmem = iris.iris(heap_size) | ||
num_ranks = shmem.get_num_ranks() | ||
|
||
element_size_bytes = torch.tensor([], dtype=dtype).element_size() | ||
n_elements = buffer_size // element_size_bytes | ||
source_buffer = shmem.arange(n_elements, dtype=dtype) | ||
|
||
shmem.barrier() | ||
|
||
args = { | ||
"datatype": torch_dtype_to_str(dtype), | ||
"block_size": block_size, | ||
"verbose": False, | ||
"validate": False, | ||
"num_experiments": 10, | ||
"num_warmup": 5, | ||
} | ||
|
||
source_rank = 0 | ||
destination_rank = 1 if num_ranks > 1 else 0 | ||
|
||
bandwidth_gbps, _ = module.run_experiment(shmem, args, source_rank, destination_rank, source_buffer) | ||
|
||
assert bandwidth_gbps > 0, f"Bandwidth should be positive, got {bandwidth_gbps}" | ||
|
||
shmem.barrier() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"dtype", | ||
[ | ||
torch.float16, | ||
torch.bfloat16, | ||
torch.float32, | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"buffer_size, heap_size", | ||
[ | ||
(20480, (1 << 33)), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"block_size", | ||
[ | ||
512, | ||
1024, | ||
], | ||
) | ||
def test_atomic_correctness(dtype, buffer_size, heap_size, block_size): | ||
"""Test that atomic_add benchmark runs and produces positive bandwidth.""" | ||
shmem = iris.iris(heap_size) | ||
num_ranks = shmem.get_num_ranks() | ||
|
||
element_size_bytes = torch.tensor([], dtype=dtype).element_size() | ||
n_elements = buffer_size // element_size_bytes | ||
source_buffer = shmem.arange(n_elements, dtype=dtype) | ||
|
||
shmem.barrier() | ||
|
||
args = { | ||
"datatype": torch_dtype_to_str(dtype), | ||
"block_size": block_size, | ||
"verbose": False, | ||
"validate": False, | ||
"num_experiments": 1, | ||
"num_warmup": 0, | ||
} | ||
|
||
source_rank = 0 | ||
destination_rank = 1 if num_ranks > 1 else 0 | ||
|
||
_, result_buffer = module.run_experiment(shmem, args, source_rank, destination_rank, source_buffer) | ||
|
||
if shmem.get_rank() == destination_rank: | ||
expected = torch.ones(n_elements, dtype=dtype, device="cuda") | ||
|
||
assert torch.allclose(result_buffer, expected), "Result buffer should be equal to expected" | ||
|
||
shmem.barrier() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.