[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types #137781

Merged: 1 commit merged into llvm:main on May 6, 2025

Conversation

@Wolfram70 (Contributor) commented Apr 29, 2025

This change:

  • Adds the cvt.f32x2.to.f8x2, cvt.f16x2.to.f8x2, and cvt.bf16x2.to.f8x2
    Ops to the NVVM dialect for the conversions to .e4m3x2, .e5m2x2,
    and .ue8m0x2 types.
  • Renames the recently added cvt.to.f6x2 Op to cvt.f32x2.to.f6x2
    for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
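As a quick orientation, below is a hypothetical usage sketch of the conversion Ops. It assumes the renamed Ops keep the assembly format of the earlier `cvt.to.f8x2` revision shown in the diff further down (a `<type>` keyword, optional `rnd`/`sat`/`relu` attributes, and an `i16` or `vector<2xi8>` result); the merged NVVMOps.td is the authority on the final syntax.

```mlir
// Hypothetical sketch: the op names are from this change, but the assembly
// format is assumed to follow the earlier cvt.to.f8x2 revision in the diff below.
llvm.func @cvt_usage_sketch(%a : f32, %b : f32, %h : vector<2xf16>, %bf : vector<2xbf16>) {
  // Two f32 inputs packed into an e4m3x2 result; the value converted from %a
  // lands in the upper 8 bits of the i16, the one from %b in the lower 8 bits.
  %r0 = nvvm.cvt.f32x2.to.f8x2 <e4m3> %a, %b
          {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : f32, f32 -> i16
  // An f16x2 vector converted element-wise to e5m2, returned as vector<2xi8>.
  %r1 = nvvm.cvt.f16x2.to.f8x2 <e5m2> %h
          {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xf16> -> vector<2xi8>
  // A bf16x2 vector converted to ue8m0 with RZ rounding (relu is not allowed here).
  %r2 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %bf
          {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
  llvm.return
}
```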

@Wolfram70 Wolfram70 requested a review from durga4github April 29, 2025 10:27
@Wolfram70 Wolfram70 self-assigned this Apr 29, 2025
@Wolfram70 Wolfram70 requested a review from grypp as a code owner April 29, 2025 10:27
@llvmbot (Member) commented Apr 29, 2025

@llvm/pr-subscribers-mlir

Author: Srinivasa Ravi (Wolfram70)

Changes

This patch adds the cvt.to.f8x2 NVVM dialect Op for conversion into f8x2 types.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt


Full diff: https://github.com/llvm/llvm-project/pull/137781.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+104)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+97)
  • (added) mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir (+71)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+88)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 27d54e7abeda9..f5eb91bc029f5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1120,6 +1120,110 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
   }];
 }
 
+def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
+def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
+def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
+
+def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
+  [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF8x2Op : NVVM_Op<"cvt.to.f8x2"> {
+  let summary = "Convert a pair of f32 or fp16 inputs to f8x2";
+  let description = [{
+    This Op converts each of the given float input types to the specified f8 
+    type.
+    The result `dst` is either represented as an i16 type or a vector
+    of two f8 types.
+    The following table describes the supported conversions and their formats:
+    ```
+    |-----------|-----------|--------------------------------------------------|
+    | Src Type  | Dst Type  | Description                                      |
+    |-----------|-----------|--------------------------------------------------|
+    |   f16x2   |  e4m3x2   |  Only operand `a` must be provided and it must   |
+    |           |  e5m2x2   |  be a vector of two F16s.                        |
+    |           |           |  If `dst` is returned as an i16 type, the        |
+    |           |           |  converted values are packed such that the       |
+    |           |           |  value converted from the first element of `a`   |
+    |           |           |  is stored in the upper 8 bits of `dst` and the  |
+    |           |           |  value converted from the second element of `a`  |
+    |           |           |  is stored in the lower 8 bits of `dst`.         |
+    |           |           |  If `dst` is returned as a vector type, each     |
+    |           |           |  converted value from `a` is stored as an i8     |
+    |           |           |  element in the vector.                          |
+    |-----------|-----------|--------------------------------------------------|
+    |   bf16x2  |  ue8m0x2  |  Only operand `a` must be provided and it must   |
+    |           |           |  be a vector of two BF16s.                       |
+    |           |           |  If `dst` is returned as an i16 type, the        |
+    |           |           |  converted values are packed such that the       |
+    |           |           |  value converted from the first element of `a`   |
+    |           |           |  is stored in the upper 8 bits of `dst` and the  |
+    |           |           |  value converted from the second element of `a`  |
+    |           |           |  is stored in the lower 8 bits of `dst`.         |
+    |           |           |  If `dst` is returned as a vector type, each     |
+    |           |           |  converted value from `a` is stored as an i8     |
+    |           |           |  element in the vector.                          |
+    |-----------|-----------|--------------------------------------------------|
+    |  f32, f32 |  e4m3x2   |  Both operands `a` and `b` must be provided and  |
+    |           |  e5m2x2   |  they must be F32 values.                        |
+    |           |  ue8m0x2  |  If `dst` is returned as an i16 type, the        |
+    |           |           |  converted values are packed such that the       |
+    |           |           |  value converted from `a` is stored in the       |
+    |           |           |  upper 8 bits of `dst` and the value converted   |
+    |           |           |  from `b` is stored in the lower 8 bits of       |
+    |           |           |  `dst`. If `dst` is returned as a vector type,   |
+    |           |           |  each converted value is stored as an i8         |
+    |           |           |  element in the vector.                          |
+    |-----------|-----------|--------------------------------------------------|
+    ```
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction for conversions to the signed f8 types (e4m3 and e5m2).
+    The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins 
+    CVTFP8TypeAttr:$type,
+    AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$a,
+    Optional<F32>:$b,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a (`,` $b^)? attr-dict `:` type($a) (`,` type($b)^)? `->` type($dst)";
+  
+  let extraClassDeclaration = [{
+    bool isFromF32Type();
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
+                                              bool isFromF32Type,
+                                              NVVM::FPRoundingMode rnd,
+                                              NVVM::SaturationMode sat,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtToF8x2Op::getIntrinsicID($type, op.isFromF32Type(), $rnd, $sat, $relu);
+    llvm::Value *packedI16;
+    if(op.isFromF32Type())
+      packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    else
+      packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                      llvm::FixedVectorType::get(
+                        llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+  
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM MMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 18453aa7f6ea9..c30c45abbdd02 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Type.h"
@@ -133,6 +134,68 @@ LogicalResult CvtFloatToTF32Op::verify() {
   return success();
 }
 
+bool CvtToF8x2Op::isFromF32Type() { return getA().getType().isF32(); }
+
+LogicalResult CvtToF8x2Op::verify() {
+  bool isFromF32 = false;
+  bool isFromF16x2 = false;
+  bool isFromBF16x2 = false;
+
+  bool isRoundingModeRN = getRnd() == NVVM::FPRoundingMode::RN;
+  bool isRoundingModeRZ = getRnd() == NVVM::FPRoundingMode::RZ;
+  bool isRoundingModeRP = getRnd() == NVVM::FPRoundingMode::RP;
+
+  bool isSatFinite = getSat() == NVVM::SaturationMode::SATFINITE;
+
+  bool hasRelu = getRelu();
+
+  if (auto vecType = dyn_cast<VectorType>(getA().getType())) {
+    isFromF16x2 = vecType.getElementType().isF16();
+    isFromBF16x2 = vecType.getElementType().isBF16();
+  } else {
+    isFromF32 = true;
+  }
+
+  if (isFromF32) {
+    if (!(getODSOperands(1).size() > 0))
+      return emitOpError("expected two f32 inputs for converting from f32");
+  } else {
+    if (getODSOperands(1).size() > 0)
+      return emitOpError(
+          "expected only a single f32, vector<2xf16> or vector<2xbf16> input "
+          "for converting from f16x2 or bf16x2, got two inputs instead.");
+  }
+
+  switch (getType()) {
+  case NVVM::CVTFP8Type::E4M3:
+  case NVVM::CVTFP8Type::E5M2:
+    if (!(isFromF32 || isFromF16x2))
+      return emitOpError("expected f32 or f16x2 input for conversions to "
+                         ".e4m3x2 or .e5m2x2 types");
+    if (!isRoundingModeRN)
+      return emitOpError("RN rounding mode required for conversions to .e4m3x2 "
+                         "or .e5m2x2 types");
+    if (!isSatFinite)
+      return emitOpError("SATFINITE saturation mode required for conversions "
+                         "to .e4m3x2 or .e5m2x2 types");
+    break;
+  case NVVM::CVTFP8Type::UE8M0:
+    if (!(isFromF32 || isFromBF16x2))
+      return emitOpError(
+          "expected f32 or bf16x2 input for conversions to .ue8m0x2 type");
+    if (!(isRoundingModeRP || isRoundingModeRZ))
+      return emitOpError(
+          "RP or RZ rounding mode required for conversions to .ue8m0x2 type");
+    if (hasRelu)
+      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
+    break;
+  default:
+    return emitOpError("unsupported FP8 type");
+  }
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1304,6 +1367,40 @@ llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
   }
 }
 
+#define CVT_TO_UE8M0X2_IMPL(fromtype, rndm, has_sat)                           \
+  has_sat ? llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm##_satfinite    \
+          : llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm
+
+#define GET_CVT_TO_UE8M0X2_ID(fromtype, rnd, has_sat)                          \
+  (rnd == NVVM::FPRoundingMode::RZ)                                            \
+      ? CVT_TO_UE8M0X2_IMPL(fromtype, _rz, has_sat)                            \
+      : CVT_TO_UE8M0X2_IMPL(fromtype, _rp, has_sat)
+
+#define GET_CVT_TO_F8X2_ID(fromtype, totype, has_relu)                         \
+  has_relu ? llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn_relu          \
+           : llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn
+
+llvm::Intrinsic::ID CvtToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type to,
+                                                bool isFromF32Type,
+                                                NVVM::FPRoundingMode rnd,
+                                                NVVM::SaturationMode sat,
+                                                bool hasRelu) {
+  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+
+  switch (to) {
+  case NVVM::CVTFP8Type::E4M3:
+    return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e4m3x2, hasRelu)
+                         : GET_CVT_TO_F8X2_ID(f16x2, e4m3x2, hasRelu);
+  case NVVM::CVTFP8Type::E5M2:
+    return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e5m2x2, hasRelu)
+                         : GET_CVT_TO_F8X2_ID(f16x2, e5m2x2, hasRelu);
+  case NVVM::CVTFP8Type::UE8M0:
+    return isFromF32Type ? GET_CVT_TO_UE8M0X2_ID(ff, rnd, hasSatFinite)
+                         : GET_CVT_TO_UE8M0X2_ID(bf16x2, rnd, hasSatFinite);
+  }
+  llvm_unreachable("Invalid CVTFP8Type for CvtToF8x2Op");
+}
+
 llvm::Intrinsic::ID
 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
new file mode 100644
index 0000000000000..06573ce53676f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp8x2_packed
+llvm.func @convert_float_to_fp8x2_packed(%srcA : f32, %srcB : f32) -> !llvm.void {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp8x2_vector
+llvm.func @convert_float_to_fp8x2_vector(%srcA : f32, %srcB : f32) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> vector<2xi8>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp8x2_with_relu
+llvm.func @convert_float_to_fp8x2_with_relu(%srcA : f32, %srcB : f32) -> !llvm.void {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp8x2
+llvm.func @convert_f16x2_to_fp8x2(%src : vector<2xf16>) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
+  llvm.return
+}
+
+
+// CHECK-LABEL: @convert_bf16x2_to_fp8x2
+llvm.func @convert_bf16x2_to_fp8x2(%src : vector<2xbf16>) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> vector<2xi8>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp8x2_with_relu
+llvm.func @convert_f16x2_to_fp8x2_with_relu(%src : vector<2xf16>) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..fc00ea6ee7003 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,91 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
   %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
   llvm.return
 }
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e5m2(%a : f32, %b : f32) {
+  // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{RP or RZ rounding mode required for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e5m2(%a : f32, %b : f32) {
+  // expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e4m3(%src : vector<2xbf16>) {
+  // expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %src : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e5m2(%src : vector<2xbf16>) {
+  // expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e5m2> %src : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_ue8m0(%src : vector<2xf16>) {
+  // expected-error @below {{expected f32 or bf16x2 input for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_two_inputs_with_fromfp16x2(%src : vector<2xf16>, %b : f32) {
+  // expected-error @below {{expected only a single f32, vector<2xf16> or vector<2xbf16> input for converting from f16x2 or bf16x2, got two inputs instead.}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %src, %b : vector<2xf16>, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_missing_second_input(%a : f32) {
+  // expected-error @below {{expected two f32 inputs for converting from f32}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : f32 -> i16
+  llvm.return
+}

@llvmbot (Member) commented Apr 29, 2025

@llvm/pr-subscribers-mlir-llvm

github-actions bot commented Apr 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-cvt-f8x2 branch from 4b022aa to d7aebe2 on April 29, 2025 at 10:54
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-cvt-f8x2 branch from d7aebe2 to d1464bb on May 2, 2025 at 10:17
@Wolfram70 Wolfram70 changed the title from "[MLIR][NVVM] Add support for f8x2 conversion" to "[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types" on May 2, 2025
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-cvt-f8x2 branch from d1464bb to 5a694bf on May 5, 2025 at 11:49
@durga4github (Contributor) left a comment

The latest changes LGTM.
Would be good to fix the hasVerifier location in the Op, though.

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-cvt-f8x2 branch 2 times, most recently from aec7cf4 to f883bb1 on May 6, 2025 at 05:18
This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2`
  Ops to the NVVM dialect for the conversions to `.e4m3x2`, `.e5m2x2`,
  and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2`
  for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-cvt-f8x2 branch from f883bb1 to c57ad76 on May 6, 2025 at 05:30
@Wolfram70 Wolfram70 merged commit bb2aa1a into llvm:main May 6, 2025
11 checks passed
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…lvm#137781)

This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2`
  Ops to the NVVM dialect for the conversions to `.e4m3x2`, `.e5m2x2`,
  and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2`
  for consistency with the other conversion Ops.

For more information, see PTX ISA:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
Comment on lines +1124 to +1126
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
Member

Creating a type for each op won't really scale. I think we should use existing types.

Contributor Author

Yeah, that makes sense. I did it this way here since the other existing type enums were prefixed with the corresponding Op name. But perhaps we should unify all of them and rename them with an NVVMFP prefix instead.

Contributor

Yes Guray, I brought this up on the first fp6 cvt Op review itself (more from a re-use perspective, though).

With a unified enum (let's say, for all the FP types), we may need to update/tighten the verifiers of many Ops to error out on unsupported types. Please let us know your thoughts on this.

Member

Do we really need an enum embedded in the NVVM dialect? I'm asking whether we can just reuse the existing MLIR builtin types. At this point, I assume we have all the exotic types. What do you think?

@Wolfram70 (Contributor, Author) commented May 13, 2025

I actually tried using the builtin types (f8E4M3FN, f8E5M2, and f8E8M0FNU) for these Ops but ran into issues during lowering to LLVM IR when bitcasting the vector to a packed i16 for the intrinsics: it looks like vectors of these types cannot be constructed / are not supported in the LLVM dialect, and an assertion fails in mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp:821 (isCompatibleVectorType()).
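To make the failure mode concrete, here is a hypothetical sketch of the kind of IR that attempt corresponds to, written with the earlier `cvt.to.f8x2` syntax from the diff above but with a builtin-FP8 result type; this variant is not what the patch implements.

```mlir
// Hypothetical variant using a builtin FP8 element type for the packed result.
// Translating this to LLVM IR trips the LLVM dialect's isCompatibleVectorType()
// check, since vectors of f8E4M3FN are not LLVM-compatible vector types.
llvm.func @cvt_with_builtin_fp8_result(%a : f32, %b : f32) {
  %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b
           {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
           : f32, f32 -> vector<2xf8E4M3FN>
  llvm.return
}
```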
