[mlir][AMDGPU] Remove an old bf16 workaround #108409


Merged: 4 commits merged into llvm:main on Sep 12, 2024

Conversation

krzysz00
Contributor

The AMDGPU backend now implements LLVM's bfloat type. Therefore, we no longer need to type convert MLIR's bf16 to i16 during lowerings to ROCDL.

As a result of this change, we discovered that, while the code for the MFMA and WMMA intrinsics was largely prepared for this change, we were failing to bitcast the bf16 results of WMMA operations back from the i16 vectors they are natively represented as. This commit also fixes that issue.
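To illustrate the fix, a minimal before/after sketch of the lowering (hand-written for this summary; the shapes and intrinsic signature follow the updated wmma.mlir test, while the SSA names, including the already-converted operands %a_i16, %c_i16, and the %opsel flag, are hypothetical):

```mlir
// Before conversion: a bf16 WMMA in the amdgpu dialect.
%d = amdgpu.wmma %a * %a + %c {subwordOffset = 1 : i32}
    : vector<16xbf16>, vector<16xbf16>, vector<16xbf16>

// After conversion: the intrinsic still operates on vectors of i16, so
// the lowering now bitcasts its result back to bf16 instead of leaving
// i16 values in the IR.
%raw = rocdl.wmma.bf16.16x16x16.bf16 %a_i16, %a_i16, %c_i16, %opsel
    : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
%d_cast = llvm.bitcast %raw : vector<16xi16> to vector<16xbf16>
```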

@llvmbot
Member

llvmbot commented Sep 12, 2024

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

The AMDGPU backend now implements LLVM's bfloat type. Therefore, we no longer need to type convert MLIR's bf16 to i16 during lowerings to ROCDL.

As a result of this change, we discovered that, while the code for the MFMA and WMMA intrinsics was largely prepared for this change, we were failing to bitcast the bf16 results of WMMA operations back from the i16 vectors they are natively represented as. This commit also fixes that issue.


Full diff: https://github.com/llvm/llvm-project/pull/108409.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+15-12)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir (+1)
  • (modified) mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir (+7-7)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c2785f34564e3b..31d35390a7e7f8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -671,18 +671,25 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type outType = typeConverter->convertType(op.getDestD().getType());
+    auto outType =
+        cast<VectorType>(typeConverter->convertType(op.getDestD().getType()));
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
+    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
+    // need to bitcast bfloats to i16 and then bitcast them back.
+    VectorType rawOutType = outType;
+    if (outType.getElementType().isBF16())
+      rawOutType = outType.clone(rewriter.getI16Type());
+
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
     OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(outType);
+    loweredOp.addTypes(rawOutType);
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +701,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     loweredOp.addOperands(operands);
     Operation *lowered = rewriter.create(loweredOp);
-    rewriter.replaceOp(op, lowered->getResults());
+
+    Operation *maybeCastBack = lowered;
+    if (rawOutType != outType)
+      maybeCastBack =
+          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+    rewriter.replaceOp(op, maybeCastBack->getResults());
 
     return success();
   }
@@ -1033,15 +1045,6 @@ struct ConvertAMDGPUToROCDLPass
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
-  converter.addConversion([](BFloat16Type t) -> Type {
-    return IntegerType::get(t.getContext(), 16);
-  });
-  converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
-    if (!t.getElementType().isBF16())
-      return std::nullopt;
-    return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
-  });
-
   patterns
       .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
            RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
index 1a4ef33db2aed5..9ca89a0babd951 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -16,6 +16,7 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
   amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index 56b65beb036954..3fa9fa5e935d2e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -445,22 +445,22 @@ gpu.module @test_module {
 
 // -----
 
-// Test that the bf16 type is lowered away on this target.
+// Test that the bf16 type is passed through to LLVM.
 
 gpu.module @test_module {
   // CHECK-LABEL: func @bf16_id
   func.func @bf16_id(%arg0 : bf16) -> bf16 {
-    // CHECK-SAME: (%[[ARG0:.+]]: i16)
-    // CHECK-SAME: -> i16
-    // CHECK: return %[[ARG0]] : i16
+    // CHECK-SAME: (%[[ARG0:.+]]: bf16)
+    // CHECK-SAME: -> bf16
+    // CHECK: return %[[ARG0]] : bf16
     func.return %arg0 : bf16
   }
 
   // CHECK-LABEL: func @bf16x4_id
   func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
-    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
-    // CHECK-SAME: -> vector<4xi16>
-    // CHECK: return %[[ARG0]] : vector<4xi16>
+    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
+    // CHECK-SAME: -> vector<4xbf16>
+    // CHECK: return %[[ARG0]] : vector<4xbf16>
     func.return %arg0 : vector<4xbf16>
   }
 

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
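To exercise the new lowering locally, a FileCheck-style test along these lines should work (a hedged sketch: the pass flag and chipset are spelled as in the existing wmma.mlir RUN lines, and the function and CHECK lines are hypothetical, mirroring the expectations added in this PR):

```mlir
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s

// CHECK-LABEL: func @bf16_wmma
func.func @bf16_wmma(%a : vector<16xbf16>, %c : vector<16xbf16>) -> vector<16xbf16> {
  // CHECK: rocdl.wmma.bf16.16x16x16.bf16
  // CHECK-NEXT: llvm.bitcast {{.*}} : vector<16xi16> to vector<16xbf16>
  %d = amdgpu.wmma %a * %a + %c {subwordOffset = 1 : i32}
      : vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
  func.return %d : vector<16xbf16>
}
```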

github-actions bot commented Sep 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

krzysz00 merged commit 6292ea6 into llvm:main on Sep 12, 2024
8 checks passed