[BACKEND] Support bf16 global atomic add on Hopper and Ampere #2708
Changes from all commits: 30e051d, 870e4b6, a55b062, e999ea9
@@ -1069,9 +1069,154 @@ struct AtomicRMWOpConversion
            converter, allocation, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

  /// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp.
  static std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
    switch (atomicOp) {
    case RMWOp::AND:
      return LLVM::AtomicBinOp::_and;
    case RMWOp::OR:
      return LLVM::AtomicBinOp::_or;
    case RMWOp::XOR:
      return LLVM::AtomicBinOp::_xor;
    case RMWOp::ADD:
      return LLVM::AtomicBinOp::add;
    case RMWOp::FADD:
      return LLVM::AtomicBinOp::fadd;
    case RMWOp::MAX:
      return LLVM::AtomicBinOp::max;
    case RMWOp::MIN:
      return LLVM::AtomicBinOp::min;
    case RMWOp::UMAX:
      return LLVM::AtomicBinOp::umax;
    case RMWOp::UMIN:
      return LLVM::AtomicBinOp::umin;
    case RMWOp::XCHG:
      return LLVM::AtomicBinOp::xchg;
    default:
      return std::nullopt;
    }
    llvm_unreachable("Invalid RMWOp");
  }

  LogicalResult
  matchAndRewriteROCm(triton::AtomicRMWOp op, OpAdaptor adaptor,
                      ConversionPatternRewriter &rewriter) const {
    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();

    auto atomicRmwAttr = op.getAtomicRmwOp();
    Value ptr = op.getPtr();
    Value val = op.getVal();

    Value llPtr = adaptor.getPtr();
    Value llVal = adaptor.getVal();
    Value llMask = adaptor.getMask();

    auto valElements = getTypeConverter()->unpackLLElements(
        loc, llVal, rewriter, val.getType());
    auto ptrElements = getTypeConverter()->unpackLLElements(
        loc, llPtr, rewriter, ptr.getType());
    SmallVector<Value> maskElements;
    if (llMask)
      maskElements = getTypeConverter()->unpackLLElements(
          loc, llMask, rewriter, op.getMask().getType());

    Value opResult = op.getResult();
    auto tensorTy = opResult.getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
        tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
                 : opResult.getType();
    const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
    auto elemsPerThread = getTotalElemsPerThread(val.getType());
    // vec = 1, numElements = 1 for scalar
    auto vec = getVectorSize(ptr);
    int numElems = 1;
    // tensor
    if (tensorTy) {
      auto valTy = val.getType().cast<RankedTensorType>();
      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
      // mask
      numElems = tensorTy.getNumElements();
    }
    Value mask = int_val(1, 1);
    auto tid = tid_val();
    mask = and_(mask,
                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));

    auto vecTy = vec_ty(valueElemTy, vec);
    auto retType = vec == 1 ? valueElemTy : vecTy;
    SmallVector<Value> resultVals(elemsPerThread);
    const bool f16v2 = vec == 2 && valueElemTy.isF16();
    for (size_t i = 0; i < elemsPerThread; i += vec) {
      Value rmwPtr = ptrElements[i];
      // TODO: in case llMask is zero we can create only one branch for all
      // elemsPerThread.
      Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

      Value undefVal = undef(retType);
      // Build blocks to bypass the atomic instruction for ~rmwMask.
      auto *curBlock = rewriter.getInsertionBlock();
      auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
      auto *atomicBlock = rewriter.createBlock(
          curBlock->getParent(), std::next(Region::iterator(curBlock)));
      endBlock->addArgument({retType}, {loc});

      rewriter.setInsertionPointToEnd(curBlock);
      rewriter.create<LLVM::CondBrOp>(loc, rmwMask, atomicBlock, endBlock,
                                      undefVal);

      rewriter.setInsertionPointToEnd(atomicBlock);
      auto maybeKind = matchAtomicOp(atomicRmwAttr);
      // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
      // atomics for MI-* series of AMD GPU.
Review comment: Was this code copy-pasted from somewhere? (I didn't review this code carefully yet.)

Reply: Cherry-picked from the ROCm fork mentioned in other comments.
      Value atom = rewriter.create<LLVM::AtomicRMWOp>(
          loc, *maybeKind, rmwPtr, valElements[i],
          LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();

      // NV for the f16v2 case generates one packed instruction. We have to
      // create two separate instructions since LLVM::AtomicRMWOp doesn't
      // support this. Can be optimized out with rocdl.raw.buffer.atomic.
      if (f16v2) {
        Value atom2 = rewriter.create<LLVM::AtomicRMWOp>(
            loc, *maybeKind, ptrElements[i + 1], valElements[i + 1],
            LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
        auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
        atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
      }
      rewriter.create<LLVM::BrOp>(loc, atom, endBlock);

      rewriter.setInsertionPointToStart(endBlock);
      Value retVal = endBlock->getArgument(0);
      if (tensorTy) {
        for (int ii = 0; ii < vec; ++ii) {
          resultVals[i + ii] =
              vec == 1 ? retVal
                       : extract_element(valueElemTy, retVal, i32_val(ii));
        }
      } else {
        Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
        atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
        store(retVal, atomPtr);
        Value ret = load(atomPtr);
        rewriter.replaceOp(op, {ret});
      }
    }
    if (tensorTy) {
      Type structTy = getTypeConverter()->convertType(tensorTy);
      Value resultStruct = getTypeConverter()->packLLElements(
          loc, resultVals, rewriter, structTy);
      rewriter.replaceOp(op, {resultStruct});
    }
    return success();
  }

  LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
    return matchAndRewriteROCm(op, adaptor, rewriter);
#endif
Review comment: I assume this is not the final state of the patch?

Reply: Yep. I just need to make sure I'm not breaking other use cases or tests by doing this.

    auto loc = op.getLoc();
    MLIRContext *ctx = rewriter.getContext();

@@ -1099,8 +1244,21 @@ struct AtomicRMWOpConversion

    auto valueTy = op.getResult().getType();
    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
    auto elementTy = tensorTy.getElementType();
    const bool isHopper =
        moduleOp->hasAttr("triton_gpu.compute-capability") &&
        moduleOp->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability")
                .getInt() >= 90;
    if (atomicRmwAttr == RMWOp::FADD && elementTy.isBF16() && !isHopper) {
      if (valElements.size() && !valElements[0].getType().isBF16()) {
        assert(false && "atom.add.bf16 fallback requires bf16 stored elements");
      } else {
        return matchAndRewriteROCm(op, adaptor, rewriter);
Review comment: Should this be called rewriteUsingCAS or something?

Reply: Well, based on my experimentation it only does an atomic CAS because the nvptx backend seems to support only that pattern and not the more compact atom.global.add PTX (this goes for f16 and f32 as well). At the moment, on Ampere there doesn't seem to be anything better we can do for bf16, so the nvptx backend's lowering is good enough as a fallback, I think.
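Editor's note: as an illustration of the fallback being discussed, here is a minimal host-side C++ sketch of a compare-and-swap retry loop for a bf16 add. It uses plain std::atomic rather than the actual PTX/LLVM lowering, the helper names are made up for the example, and the conversion truncates instead of rounding; it is meant only to show the retry pattern, not the real backend code.

```cpp
#include <atomic>
#include <cstdint>
#include <cstring>

// bf16 is the upper 16 bits of an IEEE-754 float32.
static float bf16_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

static uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16); // truncation; a real lowering rounds
}

// Retry loop: read the current value, add, and try to publish the result with
// a compare-and-swap; repeat until no other thread raced in between.
uint16_t atomic_add_bf16(std::atomic<uint16_t> &slot, uint16_t addend) {
  uint16_t expected = slot.load(std::memory_order_relaxed);
  for (;;) {
    uint16_t desired =
        float_to_bf16(bf16_to_float(expected) + bf16_to_float(addend));
    // On failure, compare_exchange_weak refreshes 'expected' with the value
    // currently stored, so the next iteration adds onto the fresh value.
    if (slot.compare_exchange_weak(expected, desired,
                                   std::memory_order_relaxed))
      return expected; // the old value, matching atomicrmw fadd semantics
  }
}
```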
      }
    }

    Type valueElemTy =
-       tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
+       tensorTy ? getTypeConverter()->convertType(elementTy)
                 : valueTy;
    const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
    auto elemsPerThread = getTotalElemsPerThread(val.getType());

@@ -1157,7 +1315,7 @@ struct AtomicRMWOpConversion
    case RMWOp::FADD:
      rmwOp = "add";
      rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
-     sTy = "f" + sBits;
+     sTy = ((isHopper && elementTy.isBF16()) ? "bf" : "f") + sBits;
Review comment: I don't understand the …

Reply: You're right. I was being a little redundant here.
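Editor's note: for illustration only (not part of the patch), this is roughly how the string pieces above combine for a packed bf16 add on Hopper, assuming vec == 2 and valueElemNBits == 16; elementIsBF16 stands in for elementTy.isBF16().

```cpp
#include <string>

int main() {
  const bool isHopper = true;        // sm_90 or newer
  const bool elementIsBF16 = true;   // stand-in for elementTy.isBF16()
  const unsigned valueElemNBits = 16;
  const unsigned vec = 2;

  std::string sBits = std::to_string(valueElemNBits);           // "16"
  std::string rmwOp = "add";
  rmwOp += (valueElemNBits == 16 ? ".noftz" : "");               // 16-bit adds skip flush-to-zero
  std::string sTy = ((isHopper && elementIsBF16) ? "bf" : "f") + sBits; // "bf16"
  sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";         // packed two-element form
  // rmwOp + "." + sTy == "add.noftz.bf16x2", the suffix of the emitted atom PTX;
  // on pre-Hopper or for f16 inputs the suffix stays "f16"/"f16x2".
}
```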
      sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
      break;
    case RMWOp::MAX:
@@ -38,7 +38,15 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
  });
  // Internally store bfloat16 as int16
  addConversion([&](BFloat16Type type) -> std::optional<Type> {
    // TODO: Experimental ifdef to try storing bf16 as bf16 since LLVM does
    // support this type now. Needed because some irgen fails if an instruction
    // operand expects a bf16 and gets an i16 instead.
#define STORE_BF16_AS_BF16 1
#if STORE_BF16_AS_BF16
    return FloatType::getBF16(type.getContext());
#else
Review comment: Would prefer we deleted this if it doesn't work with LLVM as-is today.
    return IntegerType::get(type.getContext(), 16);
#endif
  });
}
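Editor's note: a minimal sketch of what the conversion might look like if the experimental #ifdef were resolved the way the reviewer suggests, assuming LLVM's native bf16 support turns out to be sufficient to drop the i16 work-around entirely.

```cpp
  // Store bfloat16 as a real bf16 type instead of the historical i16 work-around.
  addConversion([&](BFloat16Type type) -> std::optional<Type> {
    return FloatType::getBF16(type.getContext());
  });
```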
Review comment: Is this necessary, since you have a default case anyway?

Reply: This is picked from https://github.com/ROCmSoftwarePlatform/triton; it also would be interesting to try lowering more atomics cases using generic LLVM IR in the future.
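Editor's note: the simplification hinted at here would just drop the trailing llvm_unreachable in matchAtomicOp, since the default case already covers every remaining enumerator; an abbreviated sketch of that variant follows.

```cpp
static std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
  switch (atomicOp) {
  case RMWOp::ADD:
    return LLVM::AtomicBinOp::add;
  case RMWOp::FADD:
    return LLVM::AtomicBinOp::fadd;
  // ... remaining cases exactly as in the patch ...
  default:
    return std::nullopt; // nothing matched; no llvm_unreachable needed
  }
}
```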