Skip to content

Conversation

Copilot
Copy link
Contributor

@Copilot Copilot AI commented Sep 26, 2025

Problem

The 04_atomic_add validation was failing due to incorrect validation logic:

  1. Wrong validation target: The validation was checking result_buffer which was never written to by the kernel, since atomic_add returns the previous value but the kernel didn't store it anywhere.

  2. Incorrect expected values: The validation expected wrong values that didn't match the actual atomic operation behavior.

Root Cause

The atomic add operation modifies the target buffer by adding the specified value and returns the previous value at each memory location. The original implementation had:

  • No storage of atomic_add results (as requested by maintainer feedback)
  • Validation checking unused result_buffer instead of modified source_buffer
  • Incorrect expected value calculations

Solution

1. Fixed Validation Target and Logic

Updated validation to check source_buffer (which gets modified by atomic_add) with correct expected values:

# Validate source_buffer which contains the accumulated atomic_add results
expected = torch.ones(n_elements, dtype=dtype, device="cuda")  # Expect uniform 1s
assert torch.allclose(source_buffer, expected)

2. Enhanced Test Coverage

Separated performance and correctness testing with comprehensive test suite:

# Performance test
def test_atomic_bandwidth():
    bandwidth_gbps, _ = module.run_experiment(...)
    assert bandwidth_gbps > 0

# Correctness test  
def test_atomic_correctness():
    _, result_buffer = module.run_experiment(...)
    expected = torch.ones(n_elements, dtype=dtype, device="cuda")
    assert torch.allclose(result_buffer, expected)

3. Improved Error Handling

Added proper validation error reporting and early exit on failure:

if not success:
    dist.destroy_process_group()
    sys.exit(1)

4. Clean Return Interface

Updated to return both bandwidth and buffer state for testing:

return bandwidth_gbps, source_buffer.clone()

Testing

  • Performance Testing: test_atomic_bandwidth() verifies positive bandwidth measurements across different data types and configurations
  • Correctness Testing: test_atomic_correctness() validates that atomic operations produce expected uniform values
  • Comprehensive Coverage: Tests float16, bfloat16, and float32 data types with various buffer and block sizes
  • Proper Error Handling: Early exit on validation failures with clear error reporting
  • Utility Integration: Uses torch_dtype_to_str() from common utilities for consistent datatype handling

Expected Behavior

When validation is enabled (--validate flag):

  1. Each rank performs atomic_add operations adding 1 to elements on destination rank buffers
  2. After operations complete, source_buffer elements should equal 1 uniformly
  3. Validation checks source_buffer against expected uniform values
  4. Tests verify both bandwidth performance (> 0) and correctness of final values
  5. Benchmark logic remains unchanged for performance consistency

The fix correctly addresses the validation logic while maintaining the original benchmark behavior and providing comprehensive test coverage for both performance and correctness.

Original prompt

This section details on the original issue you should resolve

<issue_title>[Issue]: 04-atomic-add validation failed</issue_title>
<issue_description>### Problem Description

The validation test failed, and there are several observations about atomic_add_kernel:

  1. It did not update result buffer, which was checked by the validation step
  2. The expected result is torch.arange(n), but warmups will keep updating the source buffer according to warmup times

I guess the idea of the validation test should be:

  1. rank 0 atomic adds 1 to remote rank 1,2,3... only once
  2. rank 1 atomic adds 1 to remote rank 0,2,3... only once
  3. then validate rank 0,1,2... 's source buffer, should all be 1,2,3...

Please correct me if there is any misunderstanding, thanks

@triton.jit
def atomic_add_kernel(
    source_buffer,  # tl.tensor: pointer to source data
    result_buffer,  # tl.tensor: pointer to result data   ## NOTE: result buffer is not udpated
    buffer_size,  # int32: total number of elements
    source_rank: tl.constexpr,
    destination_rank: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    heap_bases_ptr: tl.tensor,  # tl.tensor: pointer to heap bases pointers
):
    pid = tl.program_id(0)

    # Compute start index of this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard for out-of-bounds accesses
    mask = offsets < buffer_size

    # Get data from target buffer
    result = iris.atomic_add(
        source_buffer + offsets, 1, source_rank, destination_rank, heap_bases_ptr, mask=mask, sem="relaxed", scope="sys"
    )

Operating System

Ubuntu 22.04

CPU

AMD

GPU

MI300X

ROCm Version

ROCm 6.3

ROCm Component

No response

Steps to Reproduce

atomic_add_bench.py --validate --buffer_size=20480 --verbose -r 1

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response</issue_description>

<agent_instructions>Make sure the PyTest not only test the bandwidth to be higher than one but also tests the expected value.</agent_instructions>

Comments on the Issue (you are @copilot in this section)

@mawad-amd Thanks for the report! Yes, this looks wrong. Will fix it.
Fixes #176

💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
@Copilot Copilot AI changed the title [WIP] [Issue]: 04-atomic-add validation failed Fix atomic_add validation by storing results and implementing proper buffer reset Sep 26, 2025
@Copilot Copilot AI requested a review from mawad-amd September 26, 2025 03:07
Copilot finished work on behalf of mawad-amd September 26, 2025 03:07
…dd test return capability

Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copy link
Contributor Author

Copilot AI commented Sep 26, 2025

Just as a heads up, I was blocked by some firewall rules while working on your feedback. Expand below for details.

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • https://api.github.com/repos/ROCm/ROCm%2Firis/languages
    • Triggering command: /home/REDACTED/work/_temp/ghcca-node/node/bin/node --enable-source-maps /home/REDACTED/work/_temp/copilot-developer-action-main/dist/index.js (http block)

If you need me to access, download, or install something from one of these locations, you can either:

@Copilot Copilot AI changed the title Fix atomic_add validation by storing results and implementing proper buffer reset Fix atomic_add validation by correcting expected values and enhancing test coverage Sep 26, 2025
Copilot finished work on behalf of mawad-amd September 26, 2025 03:57
@Copilot Copilot AI requested a review from mawad-amd September 26, 2025 03:57
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
@Copilot Copilot AI requested a review from mawad-amd September 26, 2025 04:29
Copilot finished work on behalf of mawad-amd September 26, 2025 04:29
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Copilot finished work on behalf of mawad-amd September 26, 2025 04:40
Signed-off-by: Muhammad Awad <MuhammadAbdelghaffar.Awad@amd.com>
Signed-off-by: Muhammad Awad <MuhammadAbdelghaffar.Awad@amd.com>
@mawad-amd
Copy link
Collaborator

@copilot, I fixed the bugs in the test. Update the PR description now and don’t make any code changes.

@mawad-amd mawad-amd marked this pull request as ready for review September 28, 2025 23:35
@Copilot Copilot AI review requested due to automatic review settings September 28, 2025 23:35
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR fixes validation failures in the atomic_add benchmark by correcting expected test values and improving test coverage. The changes address issues with tensor value mismatches, JIT function inspection errors, and missing buffer reset logic.

Key changes:

  • Simplified test expectations to focus on functional correctness rather than complex accumulated values
  • Added proper dtype mapping utilities for torch data types
  • Updated validation logic to check source buffer instead of removed result buffer
  • Enhanced test coverage with separate bandwidth and correctness test functions

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
tests/examples/test_atomic_add_bench.py New comprehensive test suite with separate bandwidth and correctness validation
examples/common/utils.py Added dtype mapping utilities for converting between torch dtypes and string representations
examples/04_atomic_add/atomic_add_bench.py Fixed validation logic, removed result_buffer parameter, and updated expected values

Copy link
Contributor Author

Copilot AI commented Sep 28, 2025

@copilot, I fixed the bugs in the test. Update the PR description now and don’t make any code changes.

Updated the PR description to reflect the current working state after the bug fixes.

Copilot finished work on behalf of mawad-amd September 28, 2025 23:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Issue]: 04-atomic-add validation failed
2 participants