-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][AMDGPU] Remove an old bf16 workaround #108409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no longer need to type convert MLIR's `bf16` to `i16` during lowerings to ROCDL. As a result of this change, we discovered that, while the code for MFMA and WMMA intrinsics was mainly prepared for this change, we were failing to bitcast the bf16 results of WMMA operations out from the i16 they're natively represented as. This commit also fixes that issue.
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-mlir Author: Krzysztof Drewniak (krzysz00) Changes: The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no longer need to type convert MLIR's `bf16` to `i16` during lowerings to ROCDL. As a result of this change, we discovered that, while the code for MFMA and WMMA intrinsics was mainly prepared for this change, we were failing to bitcast the bf16 results of WMMA operations out from the i16 they're natively represented as. This commit also fixes that issue. Full diff: https://github.com/llvm/llvm-project/pull/108409.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c2785f34564e3b..31d35390a7e7f8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -671,18 +671,25 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- Type outType = typeConverter->convertType(op.getDestD().getType());
+ auto outType =
+ cast<VectorType>(typeConverter->convertType(op.getDestD().getType()));
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
+ // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
+ // need to bitcast bfloats to i16 and then bitcast them back.
+ VectorType rawOutType = outType;
+ if (outType.getElementType().isBF16())
+ rawOutType = outType.clone(rewriter.getI16Type());
+
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
if (!maybeIntrinsic.has_value())
return op.emitOpError("no intrinsic matching WMMA on the given chipset");
OperationState loweredOp(loc, *maybeIntrinsic);
- loweredOp.addTypes(outType);
+ loweredOp.addTypes(rawOutType);
SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +701,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
loweredOp.addOperands(operands);
Operation *lowered = rewriter.create(loweredOp);
- rewriter.replaceOp(op, lowered->getResults());
+
+ Operation *maybeCastBack = lowered;
+ if (rawOutType != outType)
+ maybeCastBack =
+ rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
}
@@ -1033,15 +1045,6 @@ struct ConvertAMDGPUToROCDLPass
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns,
Chipset chipset) {
- converter.addConversion([](BFloat16Type t) -> Type {
- return IntegerType::get(t.getContext(), 16);
- });
- converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
- if (!t.getElementType().isBF16())
- return std::nullopt;
- return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
- });
-
patterns
.add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 1a4ef33db2aed5..9ca89a0babd951 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -16,6 +16,7 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+ // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16>
amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 56b65beb036954..3fa9fa5e935d2e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -445,22 +445,22 @@ gpu.module @test_module {
// -----
-// Test that the bf16 type is lowered away on this target.
+// Test that the bf16 type is passed through to LLVM.
gpu.module @test_module {
// CHECK-LABEL: func @bf16_id
func.func @bf16_id(%arg0 : bf16) -> bf16 {
- // CHECK-SAME: (%[[ARG0:.+]]: i16)
- // CHECK-SAME: -> i16
- // CHECK: return %[[ARG0]] : i16
+ // CHECK-SAME: (%[[ARG0:.+]]: bf16)
+ // CHECK-SAME: -> bf16
+ // CHECK: return %[[ARG0]] : bf16
func.return %arg0 : bf16
}
// CHECK-LABEL: func @bf16x4_id
func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
- // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
- // CHECK-SAME: -> vector<4xi16>
- // CHECK: return %[[ARG0]] : vector<4xi16>
+ // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
+ // CHECK-SAME: -> vector<4xbf16>
+ // CHECK: return %[[ARG0]] : vector<4xbf16>
func.return %arg0 : vector<4xbf16>
}
|
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
✅ With the latest revision this PR passed the C/C++ code formatter. |
The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no longer need to type convert MLIR's `bf16` to `i16` during lowerings to ROCDL. As a result of this change, we discovered that, while the code for MFMA and WMMA intrinsics was mainly prepared for this change, we were failing to bitcast the bf16 results of WMMA operations out from the i16 they're natively represented as. This commit also fixes that issue.