[mlir][AMDGPU] Remove an old bf16 workaround #108409

Merged (4 commits) on Sep 12, 2024
29 changes: 17 additions & 12 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -671,18 +671,27 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Type outType = typeConverter->convertType(op.getDestD().getType());
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
 
     if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
       return op->emitOpError("WMMA only supported on gfx11 and gfx12");
 
+    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
+    // need to bitcast bfloats to i16 and then bitcast them back.
+    VectorType rawOutType = outType;
+    if (outType.getElementType().isBF16())
+      rawOutType = outType.clone(rewriter.getI16Type());
+
     std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
 
     if (!maybeIntrinsic.has_value())
       return op.emitOpError("no intrinsic matching WMMA on the given chipset");
 
     OperationState loweredOp(loc, *maybeIntrinsic);
-    loweredOp.addTypes(outType);
+    loweredOp.addTypes(rawOutType);
 
     SmallVector<Value, 4> operands;
     wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +703,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
 
     loweredOp.addOperands(operands);
     Operation *lowered = rewriter.create(loweredOp);
-    rewriter.replaceOp(op, lowered->getResults());
+
+    Operation *maybeCastBack = lowered;
+    if (rawOutType != outType)
+      maybeCastBack =
+          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+    rewriter.replaceOp(op, maybeCastBack->getResults());
 
     return success();
   }
@@ -1033,15 +1047,6 @@ struct ConvertAMDGPUToROCDLPass
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
-  converter.addConversion([](BFloat16Type t) -> Type {
-    return IntegerType::get(t.getContext(), 16);
-  });
-  converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
-    if (!t.getElementType().isBF16())
-      return std::nullopt;
-    return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
-  });
-
   patterns
       .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
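For illustration only (not part of the patch): with the per-op handling above, a WMMA accumulate on vector<16xbf16> operands should lower to ROCDL IR roughly like the sketch below, where the intrinsic still operates on i16 vectors and a single llvm.bitcast restores the bf16 result type. The SSA names %a, %b, %c, and %opsel are placeholders; the test updates below give the exact CHECK patterns.

// Sketch of the expected lowering for a vector<16xbf16> accumulator (assumed operand names).
%raw = rocdl.wmma.bf16.16x16x16.bf16 %a, %b, %c, %opsel : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
%acc = llvm.bitcast %raw : vector<16xi16> to vector<16xbf16>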
6 changes: 4 additions & 2 deletions mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -15,9 +15,11 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
   amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
-  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+  // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
   amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
   amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>
14 changes: 7 additions & 7 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -445,22 +445,22 @@ gpu.module @test_module {
 
 // -----
 
-// Test that the bf16 type is lowered away on this target.
+// Test that the bf16 type is passed through to LLVM.
 
 gpu.module @test_module {
   // CHECK-LABEL: func @bf16_id
   func.func @bf16_id(%arg0 : bf16) -> bf16 {
-    // CHECK-SAME: (%[[ARG0:.+]]: i16)
-    // CHECK-SAME: -> i16
-    // CHECK: return %[[ARG0]] : i16
+    // CHECK-SAME: (%[[ARG0:.+]]: bf16)
+    // CHECK-SAME: -> bf16
+    // CHECK: return %[[ARG0]] : bf16
     func.return %arg0 : bf16
   }
 
   // CHECK-LABEL: func @bf16x4_id
   func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
-    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
-    // CHECK-SAME: -> vector<4xi16>
-    // CHECK: return %[[ARG0]] : vector<4xi16>
+    // CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
+    // CHECK-SAME: -> vector<4xbf16>
+    // CHECK: return %[[ARG0]] : vector<4xbf16>
     func.return %arg0 : vector<4xbf16>
   }
 
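As a rough usage sketch (not taken from the patch), and assuming the usual func-to-llvm.func lowering on this path, the pass-through behavior checked above means a bf16 identity function now keeps its bf16 signature instead of being rewritten to i16:

// Hypothetical lowered form of @bf16_id after this change.
llvm.func @bf16_id(%arg0: bf16) -> bf16 {
  llvm.return %arg0 : bf16
}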