@@ -887,6 +887,81 @@ struct AtomicRMWOpConversion
887
887
continue ;
888
888
}
889
889
890
+ // Let LLVM lower this atomic to a compare-and-swap loop; a branch-based predicate is sufficient here
891
+ if (valueElemTy.isBF16 () && getNVIDIAComputeCapability (moduleOp) < 90 ) {
892
+ auto llvmAtomicBinOp = matchAtomicOp (atomicRmwAttr);
893
+ auto llvmAtomicMemOrdering = getMemoryOrdering (op.getSem ());
894
+
895
+ // Create basic block and branch to handle mask
896
+ auto *curBlock = rewriter.getInsertionBlock ();
897
+ auto *endBlock = curBlock->splitBlock (rewriter.getInsertionPoint ());
898
+ auto *atomicBlock = rewriter.createBlock (
899
+ curBlock->getParent (), std::next (Region::iterator (curBlock)));
900
+
901
+ // Enter into predicate block
902
+ rewriter.setInsertionPointToEnd (curBlock);
903
+ // Setup for SMEM Sync case
904
+ Value atomPtr =
905
+ tensorTy ? nullptr
906
+ : LLVM::getSharedMemoryBase (loc, rewriter, targetInfo,
907
+ op.getOperation ());
908
+ rewriter.create <LLVM::CondBrOp>(loc, pred, atomicBlock, endBlock);
909
+
910
+ // Codegen the atomic-rmw instruction(s)
911
+ rewriter.setInsertionPointToEnd (atomicBlock);
912
+ Value atom = rewriter
913
+ .create <LLVM::AtomicRMWOp>(
914
+ loc, *llvmAtomicBinOp, rmwPtr, valElements[i],
915
+ *llvmAtomicMemOrdering, StringRef (" agent" ))
916
+ .getResult ();
917
+ // Packed case (two 16-bit elements): issue a second atomic and pack both results into a vector
918
+ if (packed == 2 && valueElemNBits == 16 ) {
919
+ Value atom2 = rewriter
920
+ .create <LLVM::AtomicRMWOp>(
921
+ loc, *llvmAtomicBinOp, ptrElements[i + 1 ],
922
+ valElements[i + 1 ], *llvmAtomicMemOrdering,
923
+ StringRef (" agent" ))
924
+ .getResult ();
925
+ auto vecTy = vec_ty (valueElemTy, vec);
926
+ auto tmp =
927
+ b.insert_element (vecTy, b.undef (vecTy), atom, b.i32_val (0 ));
928
+ atom = b.insert_element (vecTy, tmp, atom2, b.i32_val (1 )).getResult ();
929
+ }
930
+
931
+ if (tensorTy) {
932
+ // Return from predicated block
933
+ rewriter.create <LLVM::BrOp>(loc, endBlock);
934
+
935
+ // Recover values from predicated block
936
+ rewriter.setInsertionPointToStart (endBlock);
937
+ Value ret = atom;
938
+ if (vec > 1 ) {
939
+ for (unsigned ii = 0 ; ii < vec; ++ii) {
940
+ resultVals[i + ii] = b.extract_val (valueElemTy, ret, ii);
941
+ }
942
+ } else if (packed > 1 ) {
943
+ for (unsigned ii = 0 ; ii < packed; ++ii) {
944
+ resultVals[i + ii] =
945
+ b.extract_element (valueElemTy, ret, b.i32_val (ii));
946
+ }
947
+ } else {
948
+ resultVals[i] = ret;
949
+ }
950
+ } else {
951
+ // Commit values from predicated block to SMEM and return from
952
+ // predicate block
953
+ b.store (atom, atomPtr);
954
+ rewriter.create <LLVM::BrOp>(loc, endBlock);
955
+
956
+ // Recover values from predicated block (from SMEM)
957
+ rewriter.setInsertionPointToStart (endBlock);
958
+ b.barrier ();
959
+ Value ret = b.load (valueElemTy, atomPtr);
960
+ rewriter.replaceOp (op, {ret});
961
+ }
962
+ continue ;
963
+ }
964
+
890
965
std::string sTy ;
891
966
PTXBuilder ptxBuilderAtomicRMW;
892
967
// 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
@@ -944,7 +1019,7 @@ struct AtomicRMWOpConversion
944
1019
case RMWOp::FADD:
945
1020
rmwOp = " add" ;
946
1021
rmwOp += (valueElemNBits == 16 ? " .noftz" : " " );
947
- sTy = " f " + sBits ;
1022
+ sTy = (valueElemTy. isBF16 () ? " bf " : " f " ) + sBits ;
948
1023
sTy += (packed == 2 && valueElemNBits == 16 ) ? " x2" : " " ;
949
1024
break ;
950
1025
case RMWOp::MAX:
0 commit comments