
[BACKEND] BF16 atomic_add support #6519


Open: plotfi wants to merge 2 commits into main from plotfi-bf16-atom-2025

Conversation

@plotfi (Contributor) commented Apr 17, 2025

This PR adds BF16 support for atomics, which are less precise but cheaper.

BF16 accumulators have proven useful in the context of Split-K kernels, where it is necessary to have cheaper atomic accumulation across two SMs (while within an SM the accumulation is handled with more precision).

BF16 atomics are also needed for some of the following AMD-related work:

  • AMD buffer atomics (i.e. BufferAtomicRMWOp)
  • A path to add unit tests for bf16 atomics for AMD's backend

BF16 atomics across A100, H100, and MI300: https://godbolt.org/z/jW3EMbxrG
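
For illustration, here is a rough sketch (not part of this PR's diff; the kernel and variable names are hypothetical) of the split-K style usage described above, where several K-splits accumulate into the same output row with a bf16 atomic add:

import torch
import triton
import triton.language as tl


@triton.jit
def splitk_accumulate(partial_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    # Program (pid_m, pid_k) holds a partial result for output row pid_m; every
    # K-split adds into the same output row via a bf16 atomic_add instead of a
    # separate reduction pass.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    row = (pid_m * tl.num_programs(1) + pid_k) * n_cols
    partial = tl.load(partial_ptr + row + offs, mask=mask)
    tl.atomic_add(out_ptr + pid_m * n_cols + offs, partial, mask=mask)


# Hypothetical usage: both tensors are bfloat16, so the atomic path added by
# this PR would be exercised.
M, SPLIT_K, N = 8, 4, 256
partial = torch.randn(M * SPLIT_K * N, device="cuda").to(torch.bfloat16)
out = torch.zeros(M * N, dtype=torch.bfloat16, device="cuda")
splitk_accumulate[(M, SPLIT_K)](partial, out, N, BLOCK=256)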

@bertmaher @SamGinzburg @davidberard98

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /python/test for end-to-end tests

@SamGinzburg (Contributor)

@scxiao This is the PR to enable BF16 atomics in Triton

@bertmaher (Collaborator)

I'd go ahead and publish this for review by the core team; we've had a sufficient number of asks for bf16 atomic_add that it feels like we should go ahead and introduce this.

plotfi force-pushed the plotfi-bf16-atom-2025 branch 2 times, most recently from bc24d8b to 20bd632 on April 25, 2025 19:02
plotfi marked this pull request as ready for review on April 25, 2025 19:03
@scxiao (Contributor) commented Apr 25, 2025

I'd go ahead and publish this for review by the core team; we've had a sufficient number of asks for bf16 atomic_add that it feels like we should go ahead and introduce this.

Triton's HIP backend has also received a lot of requests to support bf16 atomic ops.

@bertmaher (Collaborator)

Hi @ptillet, @ThomasRaoux, what do you all think of this? I know we've been around on bf16 atomic_add before, but we've pretty regularly been asked about it by internal users, and given it's supported by the hardware, it seems reasonable for Triton to also support it, I'd think?

@plotfi (Contributor Author) commented Apr 28, 2025

@joviliast Feedback welcome

@joviliast (Contributor)

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics!
The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

@ThomasRaoux (Collaborator) commented Apr 29, 2025

Hi @ptillet, @ThomasRaoux, what do you all think of this? I know we've been around on bf16 atomic_add before, but we've pretty regularly been asked about it by internal users, and given it's supported by the hardware, it seems reasonable for Triton to also support it, I'd think?

I agree, I'm supportive of this change and unless @ptillet has a concern we should go ahead with this.

@ThomasRaoux (Collaborator)

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

+1, I wonder if we can add a few cases to the existing test_atomic_rmw instead?
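
For example, a rough sketch of what such a case could look like (not from this PR; the test name is made up, values are chosen so every intermediate sum is exact in bf16, and the device argument is assumed to come from the suite's existing device fixture):

import pytest
import torch
import triton
import triton.language as tl


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16", "float32"])
def test_atomic_add_dtypes(dtype_str, device):
    # 128 ones: every partial sum (1..128) is exactly representable in bf16,
    # so the result can be checked exactly without a NumPy reference.
    dtype = getattr(torch, dtype_str)
    n = 128
    x = torch.ones(n, dtype=dtype, device=device)
    out = torch.zeros(1, dtype=dtype, device=device)

    @triton.jit
    def kernel(x_ptr, out_ptr):
        val = tl.load(x_ptr + tl.program_id(0))
        tl.atomic_add(out_ptr, val)

    kernel[(n, )](x, out)
    assert out.float().item() == float(n)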

@ThomasRaoux (Collaborator) left a comment

LGTM overall

@scxiao (Contributor) commented Apr 29, 2025

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

Hi @plotfi, @joviliast added a few other test cases in his PR #6418; maybe you can make the same changes here to cover more cases.

@plotfi (Contributor Author) commented May 1, 2025

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

+1, I wonder if we can add a few cases to the existing test_atomic_rmw instead?

@joviliast @ThomasRaoux I started adding test cases, and now I'm wondering whether bf16 could work with any atomic operation other than fadd?

@ThomasRaoux (Collaborator)

@joviliast Feedback welcome

@plotfi, thanks for enabling bf16 atomics! The only thing is, you enabled every kind of atomic operation, but only add is tested. We definitely need to verify all the operations or limit bf16 operands to a subset of atomic kinds.

+1, I wonder if we can add a few cases to the existing test_atomic_rmw instead?

@joviliast @ThomasRaoux I started adding test cases, and now I'm wondering whether bf16 could work with any atomic operation other than fadd?

the spec suggests that min/max should work as well?
https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-atom

@plotfi (Contributor Author) commented May 1, 2025

the spec suggests that min/max should work as well? https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-atom

The PTX spec looks a little strange here with regard to max/min, because it looks like they are supported in the "Atomic operation with vector type" section but not in the "Atomic operation with scalar type" section. Am I missing something?

Maybe this is why LLVM codegen falls back to a CAS here:

https://godbolt.org/z/7Yn3snxcs

Edit: Ah OK, I see now how Triton handles min/max for float:

# for float
# return atomic_smin(i_ptr, i_val) if val >= 0
# return atomic_umax(i_ptr, i_val) if val < 0

Currently this seems to only handle 32/64-bit floats. I think it makes sense to add f16/bf16, but in a separate patch?
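
As a side note, here is a tiny standalone check of the bit-pattern trick that comment relies on (plain NumPy, not Triton code): for non-negative IEEE floats the signed-integer view orders the same way as the floats, while for negative floats the unsigned view orders in reverse, which is why a signed-int min / unsigned-int max pair can emulate a float min.

import numpy as np

a, b = np.float32(1.5), np.float32(2.25)
# Non-negative floats: signed integer bit patterns order the same way as the floats.
assert (a.view(np.int32) < b.view(np.int32)) == (a < b)

c, d = np.float32(-1.5), np.float32(-2.25)
# Negative floats: the unsigned bit-pattern order is reversed, so an unsigned
# max over the patterns selects the float min.
assert (c.view(np.uint32) < d.view(np.uint32)) == (c > d)

The same reasoning would presumably carry over to 16-bit patterns for f16/bf16.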

plotfi force-pushed the plotfi-bf16-atom-2025 branch from ca790e9 to 4bc34ef on May 1, 2025 06:08
@ThomasRaoux (Collaborator)

Currently this seems to only handle 32/64-bit floats. I think it makes sense to add f16/bf16, but in a separate patch?

fine with me

plotfi force-pushed the plotfi-bf16-atom-2025 branch 3 times, most recently from 5e25b0d to b0f7b9e on May 1, 2025 19:11
@plotfi (Contributor Author) commented May 1, 2025

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bf16 type. I think test_bf16_atomics covers everything needed though.

plotfi force-pushed the plotfi-bf16-atom-2025 branch from b0f7b9e to 7bdf934 on May 1, 2025 19:14
@ThomasRaoux (Collaborator)

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bf16 type. I think test_bf16_atomics covers everything needed though.

any chance you can merge the test_core test with the existing atomic tests as suggested in this comment: #6519 (comment)

rest looks good to me

plotfi force-pushed the plotfi-bf16-atom-2025 branch from 7bdf934 to 83bb638 on May 1, 2025 22:17
@plotfi (Contributor Author) commented May 1, 2025

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bf16 type. I think test_bf16_atomics covers everything needed though.

any chance you can merge the test_core test with the existing atomic tests as suggested in this comment: #6519 (comment)

rest looks good to me

Ah yeah, I did manage to get these tests working with bf16. I had to work around the NumPy issues as #6519 did. Are there any other tests you'd like to integrate? I will make sure the tests I enabled at the very least cover the same set that #6519 did.

plotfi force-pushed the plotfi-bf16-atom-2025 branch from 83bb638 to 91967d9 on May 1, 2025 22:33
@plotfi (Contributor Author) commented May 1, 2025

@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bf16 type. I think test_bf16_atomics covers everything needed though.

any chance you can merge the test_core test with the existing atomic tests as suggested in this comment: #6519 (comment)

rest looks good to me

Took a closer look at #6519; it seems what I did to handle BF16 with NumPy differs slightly but does about the same thing. I did use FP16 for the NumPy code, and the accuracy checks do pass this way, with one exception (where I modified the atol). Let me know if we want to keep accuracy checks when comparing np.float16 versus CUDA bfloat16.

Edit: the failed AMD tests have given me my answer

plotfi force-pushed the plotfi-bf16-atom-2025 branch 3 times, most recently from 52750aa to 9e80456 on May 2, 2025 07:29
Comment on lines +7463 to +7480
@pytest.mark.interpreter
@pytest.mark.skipif(not is_cuda(), reason="Not implemented for Interpreter")
def test_bf16_atomics(device):

    @triton.jit
    def _kernel(src0, src1, dst, dst2):
        offset = tl.load(src0, None)
        val = tl.load(src1, None)
        old = tl.atomic_add(dst + offset, val)
        tl.store(dst2, old)

    acc = torch.zeros(256, dtype=torch.bfloat16, device=device)
    acc2 = torch.zeros(256, dtype=torch.bfloat16, device=device)
    idx = torch.randint(0, 256, (16 << 20, ), device=device)
    val = torch.ones(16 << 20, dtype=torch.bfloat16, device=device)

    h = _kernel[(triton.cdiv(idx.numel(), 1024), )](idx, val, acc, acc2)
    assert 'atomic_rmw' in h.asm["ttir"]
Collaborator:

do we still need that one?

Contributor Author:

do we still need that one?

You're right, I think it can be dropped. Was going to double check that before I remove it and land, though.

plotfi and others added 2 commits May 2, 2025 10:21
This revives triton-lang#2708 to add support for atomics using BF16 types,
which are less precise but cheaper.

BF16 accumulators have proven to be useful in the context of Split-K kernels,
where it is necessary to have cheaper atomic accumulation across two SMs.

BF16 atomics are also needed for some of the AMD buffer atomics work
(i.e. BufferAtomicRMWOp), as well as for a path to add unit tests
for AMD's backend.

BF16 atomics across A100, H100 and MI300 at:

https://godbolt.org/z/jW3EMbxrG
Comment on lines +7463 to +7464
@pytest.mark.interpreter
@pytest.mark.skipif(not is_cuda(), reason="Not implemented for Interpreter")
Contributor:

Suggested change (remove these two lines):
@pytest.mark.interpreter
@pytest.mark.skipif(not is_cuda(), reason="Not implemented for Interpreter")

Contributor Author:

Interpreter mode does not support BF16 afaict

@@ -1744,7 +1763,10 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
cnt += 1

kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)

# Do not check accuracy for bf16
Contributor:

Why do you ignore the accuracy checks everywhere? What's the point of running the test at all if you don't check that it outputs the correct result?

Contributor Author:

Why do you ignore the accuracy checks everywhere? What's the point of running the test at all if you don't check that it outputs the correct result?

It's the same as @joviliast's PR #6418; the reason for skipping the accuracy checks is that most of them compare against what is computed from NumPy, which does not support BF16. I tried reducing the level of accuracy (i.e. atol) to check for, but that seems to vary from machine to machine (i.e. NV vs AMD).
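
One workaround I could imagine here (a sketch only; the helper name is made up) is to round the float32/NumPy reference through bfloat16 before comparing, so both sides carry the same precision:

import numpy as np
import torch


def assert_close_as_bf16(actual_bf16, ref_f32, atol=1e-2, rtol=1e-2):
    # Round the float32 reference through bfloat16, then compare in float32.
    ref_bf16 = torch.from_numpy(np.asarray(ref_f32, dtype=np.float32)).to(torch.bfloat16)
    torch.testing.assert_close(actual_bf16.cpu().float(), ref_bf16.float(),
                               atol=atol, rtol=rtol)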

Contributor Author:

@peterbell10 One thing I will try is keeping an alternative set of checks with lower accuracy and seeing if those pass on AMD and NV. If that works, I will land it.

Collaborator:

Have you tried truncating the low mantissa bits and comparing to f32? I would expect it to be close.
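
A minimal sketch of that idea (not code from this PR; the helper name is hypothetical), masking the low 16 bits of a float32 so it carries only the precision a bfloat16 can hold before the comparison:

import torch


def truncate_to_bf16(x):
    # Zero the low 16 bits of each float32 element (truncation, not
    # round-to-nearest), leaving only the bits a bfloat16 can represent.
    return (x.float().contiguous().view(torch.int32) & -65536).view(torch.float32)


# Hypothetical usage: truncate the float32 reference, then compare it against
# the bf16 result upcast to float32 with a loose tolerance.
ref_f32 = torch.randn(16)
result_bf16 = ref_f32.to(torch.bfloat16)  # stand-in for a kernel's bf16 output
torch.testing.assert_close(truncate_to_bf16(ref_f32), result_bf16.float(),
                           atol=1e-2, rtol=1e-2)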

plotfi force-pushed the plotfi-bf16-atom-2025 branch from 9e80456 to 8b4bdd4 on May 2, 2025 17:31
neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

# triton result
rs = RandomState(17)
isBF16 = (dtype_x_str == 'bfloat16')
dtype_x_str = 'float16' if isBF16 else dtype_x_str
Collaborator:

why float16 here?
