[mlir][AMDGPU] Add scaled floating point conversion ops fp8 #141554

Open: tgymnich wants to merge 3 commits into main from tim/scaled-fp-conv
Conversation

@tgymnich (Member) commented May 27, 2025

Implement `ScaledExtPackedFp8Op` and `PackedScaledTrunc2xFp8Op`.
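
For a quick look at the surface syntax, two forms taken from the tests added in this patch:

```mlir
// Extend and scale two packed fp8 values to two f32s.
%ext = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>

// Scale and truncate two f32s into the low word of a 4-element fp8 vector.
%trunc = amdgpu.packed_scaled_trunc_2xfp8 %a, %b into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
```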

@llvmbot (Member) commented May 27, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-backend-amdgpu

Author: Tim Gymnich (tgymnich)

Changes

Implement `ScaledExtPackedFp8Op` and `PackedScaledTrunc2xFp8Op`.


Patch is 20.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141554.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+56)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+130-1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+6)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir (+108)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 02308568c1ad1..301705bd1786b 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -112,6 +112,33 @@ def AMDGPU_ExtPackedFp8Op :
   }];
 }
 
+def AMDGPU_ScaledExtPackedFp8Op :
+    AMDGPU_Op<"scaled_ext_packed_fp8", [Pure]>,
+    Arguments<(ins AnyTypeOf<[F8E5M2, F8E4M3FN,
+        VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>]>:$source, 
+      F32:$scale,
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+    Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+  let summary = "Extend a fp8 value to a float or a vector of packed fp8 values to two floats";
+
+  let description = [{
+    Extend and scale one or two 8-bit floats at `source[index]` to a 32-bit
+    float or a vector of two 32-bit floats and return them.
+
+    This rather unusual signature arises from the fact that AMD GPUs cannot
+    easily work with sub-32-bit quantities, so the compiler intrinsics for
+    extending 8-bit floats (which are, currently, the only way to work with
+    these types) take packed vectors of 4 such floats.
+
+    If the passed-in vector has fewer than four elements, or the input is
+    scalar, the remaining values in the <4 x i8> will be filled with
+    undefined values as needed.
+  }];
+  let assemblyFormat = [{
+    attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
+  }];
+}
+
 def AMDGPU_PackedTrunc2xFp8Op :
     AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
     Arguments<(ins F32:$sourceA,
@@ -139,6 +166,35 @@ def AMDGPU_PackedTrunc2xFp8Op :
   let hasVerifier = 1;
 }
 
+def AMDGPU_PackedScaledTrunc2xFp8Op :
+    AMDGPU_Op<"packed_scaled_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
+    Arguments<(ins F32:$sourceA,
+      Optional<F32>:$sourceB,
+      F32:$scale,
+      ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
+      Optional<FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
+  let summary = "Round two floats into a packed vector of 8-bit floats";
+  let description = [{
+    Scale and round the inputs `sourceA` and `sourceB` (which is undefined if not
+    specified) into the low or high word (bottom two or top two) elements
+    of the returned vector, keeping the other two elements of `existing`
+    unchanged if present (or undefined if it was not passed in).
+
+    The reason for this odd signature is that AMD GPUs cannot easily work with
+    sub-registers, and so the conversion intrinsics (which are currently the
+    only way to work with 8-bit float types) take packed vectors of 4 8-bit
+    values.
+  }];
+  let assemblyFormat = [{
+    attr-dict $sourceA `,` ($sourceB^):(`undef`)?
+    `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
+    `,` $scale
+    `:` type($sourceA) `to` type($res) (`into` type($existing)^)?
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_PackedStochRoundFp8Op :
     AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
     Arguments<(ins F32:$source,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index c5094799bbef7..5fc8e370ac4c4 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1148,6 +1148,19 @@ struct ExtPackedFp8OpLowering final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct ScaledExtPackedFp8OpLowering final
+    : public ConvertOpToLLVMPattern<ScaledExtPackedFp8Op> {
+  ScaledExtPackedFp8OpLowering(const LLVMTypeConverter &converter,
+                               Chipset chipset)
+      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedFp8Op>(converter),
+        chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct PackedTrunc2xFp8OpLowering final
     : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
   PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1161,6 +1174,20 @@ struct PackedTrunc2xFp8OpLowering final
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+struct PackedScaledTrunc2xFp8OpLowering final
+    : public ConvertOpToLLVMPattern<PackedScaledTrunc2xFp8Op> {
+  PackedScaledTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
+                                   Chipset chipset)
+      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTrunc2xFp8Op>(converter),
+        chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(PackedScaledTrunc2xFp8Op op,
+                  PackedScaledTrunc2xFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct PackedStochRoundFp8OpLowering final
     : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
   PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1229,6 +1256,67 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   return success();
 }
+// rocdl.cvt.scalef32.pk.f32.fp8 %source[false]: i32, %c4: f32 : vector<2xf32>
+// rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
+
+// amdgpu.scaled_ext_packed_fp8 %v[0]: f8E5M2, %scale: f32 : f8E5M2 to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<2xf8E5M2>, %scale: f32 : vector<2xf8E5M2> to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<4xf8E5M2>, %scale: f32 : vector<4xf8E5M2> to vector<2xf32>
+LogicalResult ScaledExtPackedFp8OpLowering::matchAndRewrite(
+    ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset != kGfx950)
+    return rewriter.notifyMatchFailure(
+        loc, "Scaled fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type v4i8 =
+      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
+
+  Value source = adaptor.getSource();
+  Value scale = adaptor.getScale();
+  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
+  Type sourceElemType = getElementTypeOrSelf(op.getSource());
+  // Extend to a v4i8
+  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
+    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+    if (!sourceVecType) {
+      longVec = rewriter.create<LLVM::InsertElementOp>(
+          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+    } else {
+      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+        Value idx = createI32Constant(rewriter, loc, i);
+        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        longVec =
+            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+      }
+    }
+    source = longVec;
+  }
+  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+  if (resultVecType) {
+    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    }
+  } else {
+    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Bf8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Fp8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    }
+  }
+  return success();
+}
 
 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
     PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
@@ -1266,6 +1354,46 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+// rocdl.cvt.scalef32.pk.fp8.f32 %sourceA: f32, %sourceB: f32, %c0: f32 ->
+// %old[false]: vector<2xi16> : vector<2xi16>
+LogicalResult PackedScaledTrunc2xFp8OpLowering::matchAndRewrite(
+    PackedScaledTrunc2xFp8Op op, PackedScaledTrunc2xFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset != kGfx950)
+    return rewriter.notifyMatchFailure(
+        loc, "Scaled fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type v2i16 = getTypeConverter()->convertType(
+      VectorType::get(2, rewriter.getI16Type()));
+
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+
+  Value sourceA = adaptor.getSourceA();
+  Value sourceB = adaptor.getSourceB();
+  Value scale = adaptor.getScale();
+  if (!sourceB)
+    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, v2i16, existing);
+  else
+    existing = rewriter.create<LLVM::UndefOp>(loc, v2i16);
+
+  Value result;
+  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
+        loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
+        loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+
+  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+      op, getTypeConverter()->convertType(resultType), result);
+  return success();
+}
+
 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
     PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1675,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
            MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
-           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+           ExtPackedFp8OpLowering, ScaledExtPackedFp8OpLowering,
+           PackedTrunc2xFp8OpLowering, PackedScaledTrunc2xFp8OpLowering,
            PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
                                                                  chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index a0a98a4e86721..b24a185d21180 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -54,6 +54,12 @@ LogicalResult PackedTrunc2xFp8Op::verify() {
   return success();
 }
 
+LogicalResult PackedScaledTrunc2xFp8Op::verify() {
+  if (getExisting() && getExisting().getType() != getResult().getType())
+    return emitOpError("existing values must have same type as result");
+  return success();
+}
+
 LogicalResult PackedStochRoundFp8Op::verify() {
   if (getExisting() && getExisting().getType() != getResult().getType())
     return emitOpError("existing values must have same type as result");
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir
new file mode 100644
index 0000000000000..128b8eabd76cd
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/8-bit-floats-scaled.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+
+// CHECK-LABEL: func @scaled_ext_scalar
+// CHECK-SAME: ([[IN:%.+]]: f8E5M2, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.bf8 [[CAST]][0], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_scalar(%v: f8E5M2, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale: f8E5M2 to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_short_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][1], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_short_vec(%v: vector<2xf8E4M3FN>, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<2xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_full_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][3], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_full_vec(%v: vector<4xf8E4M3FN>, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[3], %scale : vector<4xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_2xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][false], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]]
+func.func @scaled_ext_packed_2xfp8(%v: vector<2xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<2xf8E4M3FN> to vector<2xf32>
+  func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_4xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][true], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
+func.func @scaled_ext_packed_4xfp8(%v: vector<4xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<4xf8E4M3FN> to vector<2xf32>
+  func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @packed_scaled_trunc
+// CHECK-SAME: ([[V:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[V2]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_trunc(%v: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, undef into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_truncx2(%v: f32, %w: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.bf8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING_INT]][true] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_scaled_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>, %scale: f32) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into %existing[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index 188cfcc4eb38b..d1d56bd3b5178 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -18,6 +18,20 @@ func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
   func.return %ret : vector<2xf32>
 }
 
+// CHECK-LABEL: func @scaled_ext_packed_fp8_s
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to f32
+func.func @scaled_ext_packed_fp8_s(%v: vector<4xf8E5M2>, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_fp8_v
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to vector<2xf32
+func.func @scaled_ext_packed_fp8_v(%v: vector<4xf8E5M2>, %scale: f32) -> vector<2xf32> {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>
+  func.return %ret : vector<2xf32>
+}
+
 // CHECK-LABEL: func @packed_trunc_2xfp8
 // CHECK: amdgpu.packed_trunc_2xfp8
 func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
@@ -25,6 +39,13 @@ func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>,
   func.return %ret : vector<4xf8E5M2FNUZ>
 }
 
+// CHECK-LABEL: func @scaled_packed_trunc_2xfp8
+// ...
[truncated]

@tgymnich tgymnich changed the title [mlir][AMDGPU] implement ScaledExtPackedFp8Op and PackedScaledTrunc2xFp8Op [mlir][AMDGPU] Add scaled floating point conversion ops May 27, 2025
@tgymnich tgymnich changed the title [mlir][AMDGPU] Add scaled floating point conversion ops [mlir][AMDGPU] Add scaled floating point conversion ops fp8 May 27, 2025
@tgymnich tgymnich force-pushed the tim/scaled-fp-conv branch from a875a16 to 8e88589 Compare May 27, 2025 09:21

@krzysz00 (Contributor) left a comment

I think these can land as-is, however:

1. Would it be reasonable to add (maybe as a separate op, since the signature is different) support for packed truncation operations that take a <2 x f16> or <2 x bf16> as input?
2. On that note, making the truncation operation take an F32 and an Optional<F32> feels a bit awkward and might make things hard to lower to. Maybe we should have the truncation operation take either an F32 or a <2 x f32> and have the vector be split apart during lowering? (I'm not 100% certain on this; see the sketch after this comment.)
3. There are analogous conversions to/from packed FP4 operations (though they can't return a scalar). Given that, it might make sense to drop "fp8" from these operations' names and let them handle the fp4 instructions as well. (The fp6 ones probably need their own op, though, since they don't have tied inputs or the like.)

None of this is blocking if you want to land these now, but I wanted to flag improvements I have in mind while you have the bandwidth.

(I'm not marking approval right now just to prevent confusion. If you reply that you'd like to just land these, I'll approve.)
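
To make point 2 concrete, a purely illustrative sketch of the vector-input alternative (hypothetical syntax, not part of this patch):

```mlir
// Hypothetical: both values arrive as one <2 x f32> instead of F32 + Optional<F32>;
// the lowering would split the vector before emitting the ROCDL intrinsic.
%ret = amdgpu.packed_scaled_trunc_2xfp8 %pair into undef[word 0], %scale : vector<2xf32> to vector<4xf8E4M3FN>
```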

@tgymnich tgymnich force-pushed the tim/scaled-fp-conv branch from 8e88589 to c359b94 Compare June 4, 2025 16:04

github-actions (bot) commented Jun 4, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@tgymnich tgymnich force-pushed the tim/scaled-fp-conv branch from c359b94 to a8243eb Compare June 4, 2025 16:14

@tgymnich (Member, Author) commented Jun 4, 2025

@krzysz00 I went ahead and made the operations also support fp4. I did not include fp6, since the packing behaves completely differently (e.g. no selectable indices).
We should think about naming again; maybe add the packing width to the name?
I'll need to follow up with another op for the non-packed `cvt.scalef32` variants.
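
For illustration, a sketch of how the renamed op might look with an fp4 source, following the `scaled_ext_packed` definition quoted in the review thread below (the fp4 element type and shapes here are my assumptions, not syntax confirmed on this page):

```mlir
// Hypothetical fp4 use of the renamed op; f4E2M1FN is MLIR's fp4 type.
%ret = amdgpu.scaled_ext_packed %v[0], %scale : vector<2xf4E2M1FN> to vector<2xf32>
```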

@krzysz00 (Contributor) left a comment

Minor comment, otherwise looks good to me

def AMDGPU_ScaledExtPackedOp
    : AMDGPU_Op<"scaled_ext_packed", [Pure]>,
      Arguments<(
          ins AnyTypeOf<[VectorOfLengthAndType<[2, 3, 4], [F8E5M2, F8E4M3FN]>,

krzysz00 (Contributor):
I think we can allow 1 here?

tgymnich (Member, Author), Jun 5, 2025:

How do you think this should be implemented?
(1) using the non-pk instructions
(2) by leaving one of the 2 input vector elements undefined / zero (potentially inefficient).
(3) a combination of the above

The non-pk instructions are missing the bf16 cases and the f16 cases have different semantics (e.g. they take an existing vector input for the result to be packed into).

krzysz00 (Contributor):

I'm more thinking that if there isn't a scalar instruction, we'll want to pad with 0s, but if there is one then never mind.

... and can you point me at the bf16 asymmetry? I can't think of one off the top of my head

(Also, if there are non-packing instructions, we may want to only take even-sized vectors and let -arith-to-amdgpu handle the splitting into packed + non-packed instructions)

tgymnich (Member, Author), Jun 5, 2025:

> ... and can you point me at the bf16 asymmetry?

There is cvt.scalef32.f16.fp8 but I could not find cvt.scalef32.bf16.fp8.

> (Also, if there are non-packing instructions, we may want to only take even-sized vectors and let -arith-to-amdgpu handle the splitting into packed + non-packed instructions)

According to the function signature, cvt.scalef32.f16.fp8 still seems to be a packing instruction even though it's missing the pk in the name.

krzysz00 (Contributor):

That conversion isn't packed in the input: it only converts one single bf8/fp8 value, but it does still have to update one half of a v2f16 with the result.

    : AMDGPU_Op<"scaled_ext_packed", [Pure]>,
      Arguments<(
          ins AnyTypeOf<[VectorOfLengthAndType<[2, 3, 4], [F8E5M2, F8E4M3FN]>,
                         VectorOfLengthAndType<[2, 3, 4, 5, 6, 7, 8],

krzysz00 (Contributor):

Yeah, I suppose odd numbers work, since you can just zext

  // Extend to a packedVectorType
  if (!sourceVecType ||
      sourceVecType.getNumElements() < packedVecType.getNumElements()) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, packedVecType);

krzysz00 (Contributor):

Since we've hit issues with this before, I'd go with constant 0s here
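
A minimal sketch of that suggestion in lowered LLVM-dialect IR (my illustration, not code from this patch): start from an all-zero constant rather than llvm.mlir.undef before inserting the source elements.

```mlir
// Zero-fill the padding lanes instead of leaving them undef.
%zero = llvm.mlir.constant(dense<0> : vector<4xi8>) : vector<4xi8>
%c0 = llvm.mlir.constant(0 : i32) : i32
%vec = llvm.insertelement %elem, %zero[%c0 : i32] : vector<4xi8>
```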

@tgymnich tgymnich force-pushed the tim/scaled-fp-conv branch from a8243eb to 7b643d0 Compare June 5, 2025 15:38