Commit 54f2854

[BACKEND] BF16 atomic_add support

This revives triton-lang#2708 to add support for atomics on BF16 types, which are less precise but cheaper. BF16 accumulators have proven useful in the context of Split-K, where cheaper atomic accumulation across two SMs is necessary. BF16 atomics are also needed for some of the AMD buffer-atomics work (i.e. BufferAtomicRMWOp), as well as for a path to add unit tests for AMD's backend. See BF16 atomics across A100, H100 and MI300 at: https://godbolt.org/z/jW3EMbxrG
1 parent: 2769170
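
For context, here is a minimal sketch of the Split-K-style pattern this change enables: several programs each compute a partial reduction and atomically accumulate it into a single bf16 buffer. The kernel and names below are hypothetical illustrations, not part of this commit, and assume a CUDA device with this change applied.

import torch
import triton
import triton.language as tl


@triton.jit
def _partial_sum_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    # Each program reduces its own slice, then atomically adds the
    # partial result into one shared bf16 accumulator.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    partial = tl.sum(tl.load(x_ptr + offs), axis=0)
    tl.atomic_add(out_ptr, partial.to(tl.bfloat16))


x = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
out = torch.zeros(1, device="cuda", dtype=torch.bfloat16)
_partial_sum_kernel[(4, )](x, out, BLOCK=1024)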

4 files changed: +141 -2

python/test/unit/language/test_core.py (+20)

@@ -7436,3 +7436,23 @@ def _namedtuple_float_tuple_kernel():
         x, y = float('-inf'), float('inf')  # noqa: F841

     _namedtuple_float_tuple_kernel[(1, )]()
+
+
+@pytest.mark.interpreter
+@pytest.mark.skipif(not is_cuda(), reason="Not implemented for Interpreter")
+def test_bf16_atomics(device):
+
+    @triton.jit
+    def _kernel(src0, src1, dst, dst2):
+        offset = tl.load(src0, None)
+        val = tl.load(src1, None)
+        old = tl.atomic_add(dst + offset, val)
+        tl.store(dst2, old)
+
+    acc = torch.zeros(256, dtype=torch.bfloat16, device=device)
+    acc2 = torch.zeros(256, dtype=torch.bfloat16, device=device)
+    idx = torch.randint(0, 256, (16 << 20, ), device=device)
+    val = torch.ones(16 << 20, dtype=torch.bfloat16, device=device)
+
+    h = _kernel[(triton.cdiv(idx.numel(), 1024), )](idx, val, acc, acc2)
+    assert 'atomic_rmw' in h.asm["ttir"]
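
A note on why the test asserts on the IR rather than on numerics: with 16M increments of 1.0 spread over 256 bf16 bins, each accumulator saturates almost immediately, because bf16 carries only 8 significand bits. A quick illustration (hypothetical, not part of the commit):

import torch

# bf16 represents integers exactly only up to 256; 256 + 1 rounds back
# to 256, so a bf16 histogram of millions of ones has no exact expected
# value to compare against.
x = torch.tensor(256.0, dtype=torch.bfloat16)
assert x + 1 == x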

python/triton/language/semantic.py (+3 -1)

@@ -1381,7 +1381,9 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
     element_ty = ptr.type.scalar.element_ty
     if element_ty is tl.float16 and op != 'add':
         raise ValueError("atomic_" + op + " does not support fp16")
-    if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
+    if element_ty is tl.bfloat16 and op != 'add':
+        raise ValueError("atomic_" + op + " does not support bf16")
+    if element_ty in [tl.int1, tl.int8, tl.int16]:
         raise ValueError("atomic_" + op + " does not support " + str(element_ty))
     if ptr.type.is_block():
         if mask is not None:
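
The user-visible effect of this typechecking change: bf16 leaves the unconditionally rejected list and instead gets the same add-only rule already applied to fp16. A hypothetical sketch (kernel names are illustrative; assumes a CUDA device):

import torch
import triton
import triton.language as tl


@triton.jit
def _bf16_add(ptr):
    # Accepted after this change: atomic_add on bf16 pointers.
    tl.atomic_add(ptr + tl.arange(0, 16), tl.full((16, ), 1.0, tl.bfloat16))


@triton.jit
def _bf16_max(ptr):
    # Still rejected: any non-add atomic on bf16 raises
    # ValueError("atomic_max does not support bf16") during compilation.
    tl.atomic_max(ptr + tl.arange(0, 16), tl.full((16, ), 1.0, tl.bfloat16))


dst = torch.zeros(16, dtype=torch.bfloat16, device="cuda")
_bf16_add[(1, )](dst)  # compiles and runs
try:
    _bf16_max[(1, )](dst)
except Exception as e:  # Triton wraps the ValueError in a compilation error
    print(e)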
New file (+42)

@@ -0,0 +1,42 @@
+// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm | FileCheck %s
+
+// CHECK: llvm.atomicrmw fadd
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32,
+                   ttg.target = "cuda:80",
+                   "ttg.threads-per-warp" = 32 : i32} {
+  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  tt.func public @triton_(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32},
+                          %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
+                          %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
+                          %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %true = arith.constant true
+    %0 = tt.load %arg0 : !tt.ptr<i64>
+    %1 = tt.load %arg1 : !tt.ptr<bf16>
+    %2 = tt.addptr %arg2, %0 : !tt.ptr<bf16>, i64
+    %3 = tt.atomic_rmw fadd, acq_rel, gpu, %2, %1, %true {allocation.offset = 0 : i32} : (!tt.ptr<bf16>, bf16, i1) -> bf16
+    tt.store %arg3, %3 : !tt.ptr<bf16>
+    tt.return
+  }
+}
+
+
+// CHECK: atom.global.gpu.acq_rel.add.noftz.bf16
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32,
+                   ttg.target = "cuda:90",
+                   "ttg.threads-per-warp" = 32 : i32} {
+  llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  tt.func public @triton_(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32},
+                          %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
+                          %arg2: !tt.ptr<bf16> {tt.divisibility = 16 : i32},
+                          %arg3: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %true = arith.constant true
+    %0 = tt.load %arg0 : !tt.ptr<i64>
+    %1 = tt.load %arg1 : !tt.ptr<bf16>
+    %2 = tt.addptr %arg2, %0 : !tt.ptr<bf16>, i64
+    %3 = tt.atomic_rmw fadd, acq_rel, gpu, %2, %1, %true {allocation.offset = 0 : i32} : (!tt.ptr<bf16>, bf16, i1) -> bf16
+    tt.store %arg3, %3 : !tt.ptr<bf16>
+    tt.return
+  }
+}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp (+76 -1)
@@ -887,6 +887,81 @@ struct AtomicRMWOpConversion
         continue;
       }

+      // Let LLVM handle compare+swap loop; branch-based pred should be fine
+      if (valueElemTy.isBF16() && getNVIDIAComputeCapability(moduleOp) < 90) {
+        auto llvmAtomicBinOp = matchAtomicOp(atomicRmwAttr);
+        auto llvmAtomicMemOrdering = getMemoryOrdering(op.getSem());
+
+        // Create basic block and branch to handle mask
+        auto *curBlock = rewriter.getInsertionBlock();
+        auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
+        auto *atomicBlock = rewriter.createBlock(
+            curBlock->getParent(), std::next(Region::iterator(curBlock)));
+
+        // Enter into predicate block
+        rewriter.setInsertionPointToEnd(curBlock);
+        // Setup for SMEM Sync case
+        Value atomPtr =
+            tensorTy ? nullptr
+                     : LLVM::getSharedMemoryBase(loc, rewriter, targetInfo,
+                                                 op.getOperation());
+        rewriter.create<LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
+
+        // Codegen the atomic-rmw instruction(s)
+        rewriter.setInsertionPointToEnd(atomicBlock);
+        Value atom = rewriter
+                         .create<LLVM::AtomicRMWOp>(
+                             loc, *llvmAtomicBinOp, rmwPtr, valElements[i],
+                             *llvmAtomicMemOrdering, StringRef("agent"))
+                         .getResult();
+        // Handle the 2 bf16 case
+        if (packed == 2 && valueElemNBits == 16) {
+          Value atom2 = rewriter
+                            .create<LLVM::AtomicRMWOp>(
+                                loc, *llvmAtomicBinOp, ptrElements[i + 1],
+                                valElements[i + 1], *llvmAtomicMemOrdering,
+                                StringRef("agent"))
+                            .getResult();
+          auto vecTy = vec_ty(valueElemTy, vec);
+          auto tmp =
+              b.insert_element(vecTy, b.undef(vecTy), atom, b.i32_val(0));
+          atom = b.insert_element(vecTy, tmp, atom2, b.i32_val(1)).getResult();
+        }
+
+        if (tensorTy) {
+          // Return from predicated block
+          rewriter.create<LLVM::BrOp>(loc, endBlock);
+
+          // Recover values from predicated block
+          rewriter.setInsertionPointToStart(endBlock);
+          Value ret = atom;
+          if (vec > 1) {
+            for (unsigned ii = 0; ii < vec; ++ii) {
+              resultVals[i + ii] = b.extract_val(valueElemTy, ret, ii);
+            }
+          } else if (packed > 1) {
+            for (unsigned ii = 0; ii < packed; ++ii) {
+              resultVals[i + ii] =
+                  b.extract_element(valueElemTy, ret, b.i32_val(ii));
+            }
+          } else {
+            resultVals[i] = ret;
+          }
+        } else {
+          // Commit values from predicated block to SMEM and return from
+          // predicate block
+          b.store(atom, atomPtr);
+          rewriter.create<LLVM::BrOp>(loc, endBlock);
+
+          // Recover values from predicated block (from SMEM)
+          rewriter.setInsertionPointToStart(endBlock);
+          b.barrier();
+          Value ret = b.load(valueElemTy, atomPtr);
+          rewriter.replaceOp(op, {ret});
+        }
+        continue;
+      }
+
       std::string sTy;
       PTXBuilder ptxBuilderAtomicRMW;
       // 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
@@ -944,7 +1019,7 @@ struct AtomicRMWOpConversion
     case RMWOp::FADD:
       rmwOp = "add";
       rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
-      sTy = "f" + sBits;
+      sTy = (valueElemTy.isBF16() ? "bf" : "f") + sBits;
      sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : "";
      break;
    case RMWOp::MAX:
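
Taken together, the two hunks give bf16 atomic_add two lowering paths: on sm_90 and newer the PTX builder emits the native atom.global...bf16 instruction, while on older architectures (e.g. sm_80) the op becomes an llvm.atomicrmw fadd that LLVM expands into a compare-and-swap loop, matching the godbolt comparison linked in the commit message. A hypothetical way to see this on your own hardware (assumes a CUDA device; relies on the same h.asm handle the commit's unit test uses):

import torch
import triton
import triton.language as tl


@triton.jit
def _probe(ptr):
    tl.atomic_add(ptr + tl.arange(0, 16), tl.full((16, ), 1.0, tl.bfloat16))


dst = torch.zeros(16, dtype=torch.bfloat16, device="cuda")
h = _probe[(1, )](dst)
# sm_90+: expect "atom.global.gpu.acq_rel.add.noftz.bf16" in the PTX;
# sm_80: expect a compare-and-swap loop instead.
print(h.asm["ptx"])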
