[mlir][AMDGPU] Add scaled floating point conversion ops fp8 #141554
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-backend-amdgpu

Author: Tim Gymnich (tgymnich)

Changes: implement ScaledExtPackedFp8Op and PackedScaledTrunc2xFp8Op

Patch is 20.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141554.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 02308568c1ad1..301705bd1786b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -112,6 +112,33 @@ def AMDGPU_ExtPackedFp8Op :
}];
}
+def AMDGPU_ScaledExtPackedFp8Op :
+ AMDGPU_Op<"scaled_ext_packed_fp8", [Pure]>,
+ Arguments<(ins AnyTypeOf<[F8E5M2, F8E4M3FN,
+ VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>]>:$source,
+ F32:$scale,
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+ Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+ let summary = "Extend an fp8 value to a float or a vector of packed fp8 values to two floats";
+
+ let description = [{
+ Extend and scale one or two 8-bit floats in `source[index]` to a 32-bit float or
+ two floats and return them.
+
+ This rather unusual signature arises from the fact that AMD GPUs cannot
+ easily work with sub 32-bit quantities, so the compiler intrinsics for
+ extending 8-bit floats (which are, currently, the only way to work with
+ this operation) take packed vectors of 2 such floats.
+
+ If the passed-in vector has fewer than four elements, or the input is scalar,
+ the remaining values in the <4 x i8> will be filled with
+ undefined values as needed.
+ }];
+ let assemblyFormat = [{
+ attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
+ }];
+}
+
def AMDGPU_PackedTrunc2xFp8Op :
AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins F32:$sourceA,
@@ -139,6 +166,35 @@ def AMDGPU_PackedTrunc2xFp8Op :
let hasVerifier = 1;
}
+def AMDGPU_PackedScaledTrunc2xFp8Op :
+ AMDGPU_Op<"packed_scaled_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
+ Arguments<(ins F32:$sourceA,
+ Optional<F32>:$sourceB,
+ F32:$scale,
+ ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
+ Optional<FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>>:$existing)>,
+ Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
+ let summary = "Round two floats into a packed vector of 8-bit floats";
+ let description = [{
+ Scale and round the inputs `sourceA` and `sourceB` (which is undefined if not
+ specified) into the low or high word (bottom two or top two) elements
+ of the returned vector, keeping the other two elements of `existing`
+ unchanged if present (or undefined if it was not passed in).
+
+ The reason for this odd signature is that AMD GPUs cannot easily work with
+ sub-registers, and so the conversion intrinsics (which are currently the
+ only way to work with 8-bit float types) take packed vectors of 4 8-bit
+ values.
+ }];
+ let assemblyFormat = [{
+ attr-dict $sourceA `,` ($sourceB^):(`undef`)?
+ `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
+ `,` $scale
+ `:` type($sourceA) `to` type($res) (`into` type($existing)^)?
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_PackedStochRoundFp8Op :
AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
Arguments<(ins F32:$source,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c5094799bbef7..5fc8e370ac4c4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1148,6 +1148,19 @@ struct ExtPackedFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct ScaledExtPackedFp8OpLowering final
+ : public ConvertOpToLLVMPattern<ScaledExtPackedFp8Op> {
+ ScaledExtPackedFp8OpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedFp8Op>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedTrunc2xFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1161,6 +1174,20 @@ struct PackedTrunc2xFp8OpLowering final
ConversionPatternRewriter &rewriter) const override;
};
+struct PackedScaledTrunc2xFp8OpLowering final
+ : public ConvertOpToLLVMPattern<PackedScaledTrunc2xFp8Op> {
+ PackedScaledTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
+ Chipset chipset)
+ : ConvertOpToLLVMPattern<amdgpu::PackedScaledTrunc2xFp8Op>(converter),
+ chipset(chipset) {}
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(PackedScaledTrunc2xFp8Op op,
+ PackedScaledTrunc2xFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
struct PackedStochRoundFp8OpLowering final
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1229,6 +1256,67 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
}
return success();
}
+// rocdl.cvt.scalef32.pk.f32.fp8 %source[false]: i32, %c4: f32 : vector<2xf32>
+// rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
+
+// amdgpu.scaled_ext_packed_fp8 %v[0]: f8E5M2, %scale: f32 : f8E5M2 to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<2xf8E5M2>, %scale: f32 : vector<2xf8E5M2> to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<4xf8E5M2>, %scale: f32 : vector<4xf8E5M2> to vector<2xf32>
+LogicalResult ScaledExtPackedFp8OpLowering::matchAndRewrite(
+ ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ if (chipset != kGfx950)
+ return rewriter.notifyMatchFailure(
+ loc, "Scaled fp8 conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ Type v4i8 =
+ getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
+ Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+ Type f32 = getTypeConverter()->convertType(op.getResult().getType());
+
+ Value source = adaptor.getSource();
+ Value scale = adaptor.getScale();
+ auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+ auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
+ Type sourceElemType = getElementTypeOrSelf(op.getSource());
+ // Extend to a v4i8
+ if (!sourceVecType || sourceVecType.getNumElements() < 4) {
+ Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ if (!sourceVecType) {
+ longVec = rewriter.create<LLVM::InsertElementOp>(
+ loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ } else {
+ for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+ Value idx = createI32Constant(rewriter, loc, i);
+ Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ longVec =
+ rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ }
+ }
+ source = longVec;
+ }
+ Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ if (resultVecType) {
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ }
+ } else {
+ if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Bf8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+ rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Fp8Op>(
+ op, f32, i32Source, scale, op.getIndex());
+ }
+ }
+ return success();
+}
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
@@ -1266,6 +1354,46 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
return success();
}
+// rocdl.cvt.scalef32.pk.fp8.f32 %sourceA: f32, %sourceB: f32, %c0: f32 ->
+// %old[false]: vector<2xi16> : vector<2xi16>
+LogicalResult PackedScaledTrunc2xFp8OpLowering::matchAndRewrite(
+ PackedScaledTrunc2xFp8Op op, PackedScaledTrunc2xFp8OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ if (chipset != kGfx950)
+ return rewriter.notifyMatchFailure(
+ loc, "Scaled fp8 conversion instructions are not available on target "
+ "architecture and their emulation is not implemented");
+ Type v2i16 = getTypeConverter()->convertType(
+ VectorType::get(2, rewriter.getI16Type()));
+
+ Type resultType = op.getResult().getType();
+ Type resultElemType = getElementTypeOrSelf(resultType);
+
+ Value sourceA = adaptor.getSourceA();
+ Value sourceB = adaptor.getSourceB();
+ Value scale = adaptor.getScale();
+ if (!sourceB)
+ sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+ Value existing = adaptor.getExisting();
+ if (existing)
+ existing = rewriter.create<LLVM::BitcastOp>(loc, v2i16, existing);
+ else
+ existing = rewriter.create<LLVM::UndefOp>(loc, v2i16);
+
+ Value result;
+ if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
+ result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
+ loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+ else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
+ result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
+ loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+
+ result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+ op, getTypeConverter()->convertType(resultType), result);
+ return success();
+}
+
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1675,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
- ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+ ExtPackedFp8OpLowering, ScaledExtPackedFp8OpLowering,
+ PackedTrunc2xFp8OpLowering, PackedScaledTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index a0a98a4e86721..b24a185d21180 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -54,6 +54,12 @@ LogicalResult PackedTrunc2xFp8Op::verify() {
return success();
}
+LogicalResult PackedScaledTrunc2xFp8Op::verify() {
+ if (getExisting() && getExisting().getType() != getResult().getType())
+ return emitOpError("existing values must have same type as result");
+ return success();
+}
+
LogicalResult PackedStochRoundFp8Op::verify() {
if (getExisting() && getExisting().getType() != getResult().getType())
return emitOpError("existing values must have same type as result");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir
new file mode 100644
index 0000000000000..128b8eabd76cd
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+
+// CHECK-LABEL: func @scaled_ext_scalar
+// CHECK-SAME: ([[IN:%.+]]: f8E5M2, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.bf8 [[CAST]][0], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_scalar(%v: f8E5M2, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale: f8E5M2 to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_short_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][1], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_short_vec(%v: vector<2xf8E4M3FN>, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<2xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_full_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][3], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_full_vec(%v: vector<4xf8E4M3FN>, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[3], %scale : vector<4xf8E4M3FN> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_2xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][false], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]]
+func.func @scaled_ext_packed_2xfp8(%v: vector<2xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<2xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_4xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][true], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
+func.func @scaled_ext_packed_4xfp8(%v: vector<4xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<4xf8E4M3FN> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @packed_scaled_trunc
+// CHECK-SAME: ([[V:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[V2]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_trunc(%v: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, undef into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_truncx2(%v: f32, %w: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+ %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+ func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.bf8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING_INT]][true] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_scaled_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>, %scale: f32) -> vector<4xf8E5M2> {
+ %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into %existing[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+ func.return %ret : vector<4xf8E5M2>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 188cfcc4eb38b..d1d56bd3b5178 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -18,6 +18,20 @@ func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
func.return %ret : vector<2xf32>
}
+// CHECK-LABEL: func @scaled_ext_packed_fp8_s
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to f32
+func.func @scaled_ext_packed_fp8_s(%v: vector<4xf8E5M2>, %scale: f32) -> f32 {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to f32
+ func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_fp8_v
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to vector<2xf32
+func.func @scaled_ext_packed_fp8_v(%v: vector<4xf8E5M2>, %scale: f32) -> vector<2xf32> {
+ %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>
+ func.return %ret : vector<2xf32>
+}
+
// CHECK-LABEL: func @packed_trunc_2xfp8
// CHECK: amdgpu.packed_trunc_2xfp8
func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
@@ -25,6 +39,13 @@ func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>,
func.return %ret : vector<4xf8E5M2FNUZ>
}
+// CHECK-LABEL: func @scaled_packed_trunc_2xfp8
+// ...
[truncated]
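For orientation before the review discussion below, here is a minimal sketch of how the two new ops compose, assembled from the tests in this patch (the function and value names are illustrative, not part of the patch):

```mlir
// Sketch only: operand order and types follow the tests above.
func.func @example(%v: vector<4xf8E5M2>, %a: f32, %b: f32, %scale: f32)
    -> (vector<2xf32>, vector<4xf8E5M2>) {
  // Scaled extension: select word 0 of the packed source and extend two
  // fp8 values to two f32s, scaling by %scale.
  %ext = amdgpu.scaled_ext_packed_fp8 %v[0], %scale
      : vector<4xf8E5M2> to vector<2xf32>
  // Scaled truncation: round %a and %b into word 1 of a fresh packed vector;
  // the other word is left undefined.
  %trunc = amdgpu.packed_scaled_trunc_2xfp8 %a, %b into undef[word 1], %scale
      : f32 to vector<4xf8E5M2>
  func.return %ext, %trunc : vector<2xf32>, vector<4xf8E5M2>
}
```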
Force-pushed a875a16 to 8e88589
I think these can land as-is, however:

1. Would it be reasonable to add (maybe as a separate op, since the signature is different) support for the packed truncation operations that take a `<2 x f16>` or `<2 x bf16>` as an input?
2. On that note, making the truncation operation take a `F32` and an `Optional<F32>` feels a bit awkward and might make things hard to lower to. Maybe we should have the truncation operation take either a `F32` or a `<2 x f32>` and have the vector be split apart during lowering? (I'm not 100% certain on this; see the sketch after this comment.)
3. There are analogous conversions to/from packed FP4 operations (though they can't return a scalar). Given that, it might make sense to have these operations not have "fp8" in their name, letting them handle the fp4 instructions as well. (The fp6 ones probably need their own op, though, since they don't have tied inputs or the like.)

None of this is blocking if you want to land these now, but I wanted to flag improvements I have in mind while you have the bandwidth.

(I'm not marking approval right now just to prevent confusion. If you reply that you'd like to just land these, I'll approve.)
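One possible reading of suggestion 2, sketched in the op's current assembly format. This is purely illustrative; the single-vector operand form is a hypothetical alternative, not syntax from this patch:

```mlir
// Hypothetical alternative surface syntax: one <2 x f32> operand instead of
// sourceA plus an optional sourceB; the lowering would split %pair into two
// f32s before emitting the ROCDL intrinsic.
%ret = amdgpu.packed_scaled_trunc_2xfp8 %pair into undef[word 0], %scale
    : vector<2xf32> to vector<4xf8E4M3FN>
```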
Force-pushed 8e88589 to c359b94
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed c359b94 to a8243eb
@krzysz00 I went ahead and made the operations also support fp4. I did not include fp6, since the packing behaves completely differently (e.g. no selectable indices).
Minor comment, otherwise looks good to me
def AMDGPU_ScaledExtPackedOp
    : AMDGPU_Op<"scaled_ext_packed", [Pure]>,
      Arguments<(
          ins AnyTypeOf<[VectorOfLengthAndType<[2, 3, 4], [F8E5M2, F8E4M3FN]>,
I think we can allow `1` here?
How do you think this should be implemented?
(1) using the non-pk instructions
(2) by leaving one of the 2 input vector elements undefined / zero (potentially inefficient).
(3) combination of the above
The non-pk instructions are missing the bf16 cases and the f16 cases have different semantics (e.g. they take an existing vector input for the result to be packed into).
I'm more thinking that if there isn't a scalar instruction, we'll want to pad with 0s, but if there is one then never mind.
... and can you point me at the bf16 asymmetry? I can't think of one off the top of my head
(Also, if there are non-packing instructions, we may want to only take even-sized vectors and let -arith-to-amdgpu handle the splitting into packed + non-packed instructions)
> ... and can you point me at the bf16 asymmetry?

There is `cvt.scalef32.f16.fp8` but I could not find `cvt.scalef32.bf16.fp8`.

> (Also, if there are non-packing instructions, we may want to only take even-sized vectors and let -arith-to-amdgpu handle the splitting into packed + non-packed instructions)

According to the function signature, `cvt.scalef32.f16.fp8` still seems to be a packing instruction even though it is missing the pk in the name.
That conversion isn't packed in the input - it only converts a single bf8/fp8, but it does still have to update one half of a v2f16 with the result.
    : AMDGPU_Op<"scaled_ext_packed", [Pure]>,
      Arguments<(
          ins AnyTypeOf<[VectorOfLengthAndType<[2, 3, 4], [F8E5M2, F8E4M3FN]>,
                         VectorOfLengthAndType<[2, 3, 4, 5, 6, 7, 8],
Yeah, I suppose odd numbers work, since you can just zext
  // Extend to a packedVectorType
  if (!sourceVecType ||
      sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, packedVecType);
Since we've hit issues with this before, I'd go with constant 0s here
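A minimal sketch of that suggestion against the quoted hunk, assuming the surrounding lowering code from this revision (`packedVecType`, `loc`, and `rewriter` as above); `LLVM::ZeroOp` emits `llvm.mlir.zero`:

```cpp
// Sketch: seed the padding vector with zeros rather than undef, so the
// unused lanes have a defined value before the source elements are inserted.
Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
```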
Force-pushed a8243eb to 7b643d0
implement `ScaledExtPackedFp8Op` and `PackedScaledTrunc2xFp8Op`