
[BACKEND] Support bf16 global atomic add on Hopper and Ampere #2708

Closed
162 changes: 160 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -1069,9 +1069,154 @@ struct AtomicRMWOpConversion
converter, allocation, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}

/// Try to match the mlir::triton::RMWOp to LLVM::AtomicBinOp.
static std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
switch (atomicOp) {
case RMWOp::AND:
return LLVM::AtomicBinOp::_and;
case RMWOp::OR:
return LLVM::AtomicBinOp::_or;
case RMWOp::XOR:
return LLVM::AtomicBinOp::_xor;
case RMWOp::ADD:
return LLVM::AtomicBinOp::add;
case RMWOp::FADD:
return LLVM::AtomicBinOp::fadd;
case RMWOp::MAX:
return LLVM::AtomicBinOp::max;
case RMWOp::MIN:
return LLVM::AtomicBinOp::min;
case RMWOp::UMAX:
return LLVM::AtomicBinOp::umax;
case RMWOp::UMIN:
return LLVM::AtomicBinOp::umin;
case RMWOp::XCHG:
return LLVM::AtomicBinOp::xchg;
default:
return std::nullopt;
}
llvm_unreachable("Invalid RMWOp");
Collaborator: Is this necessary, since you have a default case anyway?

Contributor Author: This is picked from https://github.com/ROCmSoftwarePlatform/triton; it would also be interesting to try lowering more atomics cases through generic LLVM IR in the future.

}

LogicalResult
matchAndRewriteROCm(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();

auto atomicRmwAttr = op.getAtomicRmwOp();
Value ptr = op.getPtr();
Value val = op.getVal();

Value llPtr = adaptor.getPtr();
Value llVal = adaptor.getVal();
Value llMask = adaptor.getMask();

auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, val.getType());
auto ptrElements = getTypeConverter()->unpackLLElements(
loc, llPtr, rewriter, ptr.getType());
SmallVector<Value> maskElements;
if (llMask)
maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, op.getMask().getType());

Value opResult = op.getResult();
auto tensorTy = opResult.getType().dyn_cast<RankedTensorType>();
Type valueElemTy =
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: opResult.getType();
const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
int numElems = 1;
// tensor
if (tensorTy) {
auto valTy = val.getType().cast<RankedTensorType>();
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
// mask
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));

auto vecTy = vec_ty(valueElemTy, vec);
auto retType = vec == 1 ? valueElemTy : vecTy;
SmallVector<Value> resultVals(elemsPerThread);
const bool f16v2 = vec == 2 && valueElemTy.isF16();
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwPtr = ptrElements[i];
// TODO: in case llMask is zero we can create only one branch for all
// elemsPerThread.
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

Value undefVal = undef(retType);
// Build blocks to bypass the atomic instruction for ~rmwMask.
auto *curBlock = rewriter.getInsertionBlock();
auto *endBlock = curBlock->splitBlock(rewriter.getInsertionPoint());
auto *atomicBlock = rewriter.createBlock(
curBlock->getParent(), std::next(Region::iterator(curBlock)));
endBlock->addArgument({retType}, {loc});

rewriter.setInsertionPointToEnd(curBlock);
rewriter.create<LLVM::CondBrOp>(loc, rmwMask, atomicBlock, endBlock,
undefVal);

rewriter.setInsertionPointToEnd(atomicBlock);
auto maybeKind = matchAtomicOp(atomicRmwAttr);
// TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
// atomics for MI-* series of AMD GPU.
Collaborator: Was this code copy-pasted from somewhere? (I didn't review this code carefully yet.)

Contributor Author: Cherry-picked from the ROCm fork mentioned in the other comments.

Value atom = rewriter.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, rmwPtr, valElements[i],
LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();

// NV for the f16v2 case generates one packed instruction. We have to
// create two separate instructions since LLVM::AtomicRMWOp doesn't
// support this. Can be optimized out with rocdl.raw.buffer.atomic.
if (f16v2) {
Value atom2 = rewriter.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, ptrElements[i+1], valElements[i + 1],
LLVM::AtomicOrdering::monotonic, StringRef("agent")).getResult();
auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
}
rewriter.create<LLVM::BrOp>(loc, atom, endBlock);

rewriter.setInsertionPointToStart(endBlock);
Value retVal = endBlock->getArgument(0);
if (tensorTy) {
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? retVal
: extract_element(valueElemTy, retVal, i32_val(ii));
}
} else {
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
atomPtr = bitcast(atomPtr, ptr_ty(valueElemTy, 3));
store(retVal, atomPtr);
Value ret = load(atomPtr);
rewriter.replaceOp(op, {ret});
}
}
if (tensorTy) {
Type structTy = getTypeConverter()->convertType(tensorTy);
Value resultStruct = getTypeConverter()->packLLElements(
loc, resultVals, rewriter, structTy);
rewriter.replaceOp(op, {resultStruct});
}
return success();
}

LogicalResult
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
#ifdef USE_ROCM
return matchAndRewriteROCm(op, adaptor, rewriter);
#endif
Collaborator: I assume this is not the final state of the patch?

Contributor Author: Yep. I just need to make sure I'm not breaking other use cases or tests by doing this.


auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();

@@ -1099,8 +1244,21 @@ struct AtomicRMWOpConversion

auto valueTy = op.getResult().getType();
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
auto elementTy = tensorTy.getElementType();
const bool isHopper =
moduleOp->hasAttr("triton_gpu.compute-capability") &&
moduleOp->getAttrOfType<IntegerAttr>("triton_gpu.compute-capability")
.getInt() >= 90;
if (atomicRmwAttr == RMWOp::FADD && elementTy.isBF16() && !isHopper) {
if (valElements.size() && !valElements[0].getType().isBF16()) {
assert(false && "atom.add.bf16 fallback requires bf16 stored elements");
} else {
return matchAndRewriteROCm(op, adaptor, rewriter);
Collaborator: Should this be called rewriteUsingCAS or something?

Contributor Author: Well, based on my experimentation it only does an atomic CAS because the NVPTX backend seems to support only that pattern and not the more compact atom.global.add PTX (this goes for f16 and f32 as well). At the moment, for bf16 on Ampere there doesn't seem to be anything better we can do, so the NVPTX backend's lowering is good enough as a fallback, I think.

}
}

Type valueElemTy =
- tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
+ tensorTy ? getTypeConverter()->convertType(elementTy)
: valueTy;
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(val.getType());
@@ -1157,7 +1315,7 @@ struct AtomicRMWOpConversion
case RMWOp::FADD:
rmwOp = "add";
rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
sTy = "f" + sBits;
sTy = ((isHopper && elementTy.isBF16()) ? "bf" : "f") + sBits;
Collaborator: I don't understand the isHopper check here. If it's not Hopper, shouldn't it be an error (maybe caught above) to try to use bf16?

Contributor Author: You're right. I was being a little redundant here.

sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
break;
case RMWOp::MAX:
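A quick way to see which of the two lowerings discussed above actually fires is to look at the generated PTX: on Hopper the inline-asm path should emit an add.noftz.bf16-style atom instruction, while on Ampere the generic LLVM atomicrmw fadd fallback is lowered by the NVPTX backend into a compare-and-swap loop. The sketch below is illustrative only (not part of this PR) and assumes the kernel launch handle exposes the generated PTX via .asm["ptx"], an accessor that varies across Triton versions.

import torch
import triton
import triton.language as tl

@triton.jit
def bf16_atomic_add_kernel(ptr, val):
    # one scalar bf16 atomic add per program instance
    tl.atomic_add(ptr + tl.program_id(0), val)

x = torch.zeros(128, device="cuda", dtype=torch.bfloat16)
pgm = bf16_atomic_add_kernel[(128,)](x, 1.0)  # launch handle (assumed API)
ptx = pgm.asm["ptx"]                          # assumed accessor
# Hopper (sm_90): expect an add.noftz.bf16 atom from the inline-asm path.
# Ampere (sm_80): expect a CAS loop emitted by the NVPTX fallback instead.
print("add.noftz.bf16" in ptx, "cas" in ptx)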
8 changes: 8 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -38,7 +38,15 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> std::optional<Type> {
// TODO: Experimental ifdef to try storing bf16 as bf16 since LLVM does
// support this type now. Needed because some irgen fails if an instruction
// operand expects a bf16 and gets an i16 instead.
#define STORE_BF16_AS_BF16 1
#if STORE_BF16_AS_BF16
return FloatType::getBF16(type.getContext());
#else
Collaborator: Would prefer we deleted this if it doesn't work with LLVM as-is today.

return IntegerType::get(type.getContext(), 16);
#endif
});
}

2 changes: 1 addition & 1 deletion python/triton/language/semantic.py
@@ -1032,7 +1032,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor,
element_ty = ptr.type.scalar.element_ty
if element_ty is tl.float16 and op != 'add':
raise ValueError("atomic_" + op + " does not support fp16")
- if element_ty in [tl.int1, tl.int8, tl.int16, tl.bfloat16]:
+ if element_ty in [tl.int1, tl.int8, tl.int16]:
raise ValueError("atomic_" + op + " does not support " + str(element_ty))
if ptr.type.is_block():
if mask:
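With the bfloat16 restriction removed from atom_red_typechecking_impl, a kernel along these lines should now pass typechecking and compile (an illustrative sketch, not a test from this PR; kernel and buffer names are made up):

import torch
import triton
import triton.language as tl

@triton.jit
def scatter_add_bf16(out_ptr, x_ptr, idx_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    idx = tl.load(idx_ptr + offs, mask=mask, other=0)
    # atomics are needed because different program instances may hit the same bin
    tl.atomic_add(out_ptr + idx, x, mask=mask)

n, bins = 4096, 512
x = torch.randn(n, device="cuda", dtype=torch.bfloat16)
idx = torch.randint(0, bins, (n,), device="cuda", dtype=torch.int32)
out = torch.zeros(bins, device="cuda", dtype=torch.bfloat16)
scatter_add_bf16[(triton.cdiv(n, 256),)](out, x, idx, n, BLOCK=256)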