[BACKEND] BF16 atomic_add support #6519
Conversation
@scxiao This is the PR to enable BF16 atomics in Triton
I'd go ahead and publish this for review by the core team; we've had a sufficient number of asks for bf16 atomic_add that it feels like we should go ahead and introduce this.
Force-pushed from bc24d8b to 20bd632
The Triton HIP backend has also received a lot of requests to support bf16 atomic ops.
Force-pushed from 20bd632 to ca790e9
Hi @ptillet, @ThomasRaoux, what do you all think of this? I know we've been around on bf16 atomic_add before, but we've pretty regularly been asked about it by internal users, and given that it's supported by the hardware, it seems reasonable for Triton to also support it, I'd think?
@joviliast Feedback welcome
@plotfi, thanks for enabling bf16 atomics!
I agree, I'm supportive of this change, and unless @ptillet has a concern we should go ahead with this.
+1, I wonder if we can add a few cases to the existing tests.
LGTM overall
Hi @plotfi, @joviliast added a few other test cases in his PR #6418; maybe you can make the same changes here to cover more cases.
@joviliast @ThomasRaoux I started adding test cases, and I'm now wondering whether bf16 could work with any atomic operation other than fadd?
The spec suggests that min/max should work as well?
The PTX spec looks a little strange here with regard to max/min: it appears to be supported in the "Atomic operation with vector type" section but not in the "Atomic operation with scalar type" section. Am I missing something? Maybe this is why LLVM codegen falls back to a CAS here: https://godbolt.org/z/7Yn3snxcs
Edit: Ah, OK, I see now how Triton handles min/max for float: triton/python/triton/language/semantic.py, lines 1460 to 1462 at c7fc1e3.
Currently they seem to only handle 32/64-bit float. I think it makes sense to add f16/bf16, but in a separate patch?
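For anyone skimming the thread, here is a minimal, illustrative sketch (plain Python, not Triton or PTX) of the CAS-loop fallback mentioned above: with no native float atomic max, the operation can be emulated by retrying compare-and-swap on the raw bits. `FakeAtomicCell` and `compare_exchange` are hypothetical stand-ins for device memory and the hardware CAS instruction.

```python
import struct
import threading


def float_to_bits(f: float) -> int:
    return struct.unpack("<I", struct.pack("<f", f))[0]


def bits_to_float(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]


class FakeAtomicCell:
    """A toy 32-bit memory cell exposing only a load and a compare-and-swap."""

    def __init__(self, value: float):
        self._bits = float_to_bits(value)
        self._lock = threading.Lock()

    def load_bits(self) -> int:
        return self._bits

    def compare_exchange(self, expected: int, desired: int) -> int:
        # Returns the bits that were observed; the store happens only if they
        # matched `expected` (mimicking a hardware CAS).
        with self._lock:
            observed = self._bits
            if observed == expected:
                self._bits = desired
            return observed


def atomic_max_via_cas(cell: FakeAtomicCell, val: float) -> float:
    """Emulate atomic max with a CAS retry loop; returns the previous value."""
    old_bits = cell.load_bits()
    while True:
        new_val = max(bits_to_float(old_bits), val)
        observed = cell.compare_exchange(old_bits, float_to_bits(new_val))
        if observed == old_bits:
            return bits_to_float(old_bits)
        old_bits = observed  # another thread won the race; retry with its value


cell = FakeAtomicCell(1.0)
print(atomic_max_via_cas(cell, 2.5), bits_to_float(cell.load_bits()))  # 1.0 2.5
```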
Force-pushed from ca790e9 to 4bc34ef
fine with me
Force-pushed from 5e25b0d to b0f7b9e
@ThomasRaoux Do things look good enough for an approval here? I looked at some of the existing tests for fp16, but they all appear to depend on NumPy, which does not have a bf16 type. I think test_bf16_atomics covers everything needed, though.
Force-pushed from b0f7b9e to 7bdf934
Any chance you can merge the test_core test with the existing atomic tests, as suggested in this comment: #6519 (comment)? The rest looks good to me.
Force-pushed from 7bdf934 to 83bb638
Ah yeah, I did manage to get these tests working with bf16; I had to work around the NumPy issues as #6519 did. Are there any other tests you'd like to integrate? I will make sure the tests I enabled cover at the very least the same set that #6519 did.
Force-pushed from 83bb638 to 91967d9
Took a closer look at #6519; it seems what I did to handle BF16 with NumPy differs slightly but does about the same thing. I did use FP16 for the NumPy code, and the accuracy checks do pass this way, with one exception (where I modified the atol). Let me know if we want to keep accuracy checks when comparing np.float16 against CUDA bfloat16.
Edit: the failed AMD tests have given me my answer.
Force-pushed from 52750aa to 9e80456
@pytest.mark.interpreter
@pytest.mark.skipif(not is_cuda(), reason="Not implemented for Interpreter")
def test_bf16_atomics(device):

    @triton.jit
    def _kernel(src0, src1, dst, dst2):
        offset = tl.load(src0, None)
        val = tl.load(src1, None)
        old = tl.atomic_add(dst + offset, val)
        tl.store(dst2, old)

    acc = torch.zeros(256, dtype=torch.bfloat16, device=device)
    acc2 = torch.zeros(256, dtype=torch.bfloat16, device=device)
    idx = torch.randint(0, 256, (16 << 20, ), device=device)
    val = torch.ones(16 << 20, dtype=torch.bfloat16, device=device)

    h = _kernel[(triton.cdiv(idx.numel(), 1024), )](idx, val, acc, acc2)
    assert 'atomic_rmw' in h.asm["ttir"]
do we still need that one?
You're right, I think it can be dropped. I was going to double-check that before I remove it and land, though.
Making these functions more reusable
This revives triton-lang#2708 to add support for atomics using BF16 types, which are less precise but cheaper. BF16 accumulators have proven to be useful in the context of split-K, where it is necessary to have cheaper atomic accumulation across two SMs. BF16 atomics are also needed for some of the AMD buffer atomics work (i.e. BufferAtomicRMWOp), as well as for a path to add unit tests for AMD's backend. BF16 atomics across A100, H100 and MI300 at: https://godbolt.org/z/jW3EMbxrG
@pytest.mark.interpreter
@pytest.mark.skipif(not is_cuda(), reason="Not implemented for Interpreter")

Interpreter mode does not support BF16 afaict
@@ -1744,7 +1763,10 @@ def kernel(in_ptr, idx_ptr, out_ptr, shape0, shape1, mask_step, XBLOCK: tl.const
    cnt += 1

    kernel[(1, )](val, idx, dst, shape0, shape1, mask_step, 64, num_ctas=num_ctas)
    np.testing.assert_allclose(to_numpy(dst_ref), to_numpy(dst), atol=1e-2)

    # Do not check accuracy for bf16
Why do you ignore the accuracy checks everywhere? What's the point of running the test at all if you don't check that it outputs the correct result.
It's the same as @joviliast's PR at #6418; the reason for skipping the accuracy checks is that most of them compare against what is computed with NumPy, which does not support BF16. I tried reducing the level of accuracy (i.e. the atol) to check for, but that seems to vary from machine to machine (i.e. NV vs AMD).
@peterbell10 One thing I will try is keeping an alternative set of checks with lower accuracy and seeing if those pass on AMD and NV. If this works, I will land that.
Have you tried truncating the low mantissa bits and comparing to f32? I would expect it to be close.
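As a rough illustration of that suggestion (not code from this PR): compute the NumPy reference in float32, then zero the low 16 mantissa bits so only bf16-representable values remain before comparing. The helper `truncate_to_bf16` is hypothetical, and it truncates rather than rounds to nearest, so it only approximates real bf16 rounding.

```python
import numpy as np


def truncate_to_bf16(x: np.ndarray) -> np.ndarray:
    # Zero the low 16 bits of each float32, keeping only the bits a bf16 can hold.
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


# Usage sketch: `dst` is the device bf16 result, `dst_ref` the float32 NumPy reference.
# np.testing.assert_allclose(truncate_to_bf16(dst_ref),
#                            dst.to(torch.float32).cpu().numpy(),
#                            atol=1e-2)
```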
Force-pushed from 9e80456 to 8b4bdd4
    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]

    # triton result
    rs = RandomState(17)
    isBF16 = (dtype_x_str == 'bfloat16')
    dtype_x_str = 'float16' if isBF16 else dtype_x_str
why float16 here?
This PR adds BF16 support for atomics, which are less precise but cheaper.
BF16 accumulators have proven to be useful in the context of split-K, where it is necessary to have cheaper atomic accumulation across two SMs (within an SM the accumulation is handled with more precision).
BF16 atomics are also needed for some of the following AMD-related work:
BF16 atomics across A100, H100 and MI300 at: https://godbolt.org/z/jW3EMbxrG
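For context, a hypothetical sketch (not code from this PR) of the split-K pattern described above: each program keeps its partial result in fp32 and then accumulates across programs with the cheaper bf16 atomic add this PR enables. The kernel name `splitk_accumulate` and its launch parameters are made up for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def splitk_accumulate(partial_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Each program's partial result is computed and held in fp32...
    part = tl.load(partial_ptr + offs, mask=mask, other=0.0)
    # ...and accumulated across programs with a bf16 atomic add.
    tl.atomic_add(out_ptr + offs, part.to(tl.bfloat16), mask=mask)


if __name__ == "__main__":
    n = 1024
    partial = torch.randn(n, device="cuda", dtype=torch.float32)
    out = torch.zeros(n, device="cuda", dtype=torch.bfloat16)
    splitk_accumulate[(triton.cdiv(n, 256), )](partial, out, n, BLOCK=256)
```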
@bertmaher @SamGinzburg @davidberard98
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have run pre-commit run --from-ref origin/main --to-ref HEAD.
Select one of the following:
/test for lit tests
/python/test for end-to-end tests