[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types #137781
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

Changes

This patch adds the `cvt.to.f8x2` Op to the NVVM dialect for conversions to the `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` f8x2 types.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

Full diff: https://github.com/llvm/llvm-project/pull/137781.diff

4 Files Affected:
- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
- mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir (new)
- mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 27d54e7abeda9..f5eb91bc029f5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1120,6 +1120,110 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
}];
}
+def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
+def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
+def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
+
+def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
+ [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF8x2Op : NVVM_Op<"cvt.to.f8x2"> {
+ let summary = "Convert a pair of f32 or fp16 inputs to f8x2";
+ let description = [{
+ This Op converts each of the given float input types to the specified f8
+ type.
+ The result `dst` is either represented as an i16 type or a vector
+ of two f8 types.
+ The following table describes the supported conversions and their formats:
+ ```
+ |-----------|-----------|--------------------------------------------------|
+ | Src Type | Dst Type | Description |
+ |-----------|-----------|--------------------------------------------------|
+ | f16x2 | e4m3x2 | Only operand `a` must be provided and it must |
+ | | e5m2x2 | be a vector of two F16s. |
+ | | | If `dst` is returned as an i16 type, the |
+ | | | converted values are packed such that the |
+ | | | value converted from the first element of `a` |
+ | | | is stored in the upper 8 bits of `dst` and the |
+ | | | value converted from the second element of `a` |
+ | | | is stored in the lower 8 bits of `dst`. |
+ | | | If `dst` is returned as a vector type, each |
+ | | | converted value from `a` is stored as an i8 |
+ | | | element in the vector. |
+ |-----------|-----------|--------------------------------------------------|
+ | bf16x2 | ue8m0x2 | Only operand `a` must be provided and it must |
+ | | | be a vector of two BF16s. |
+ | | | If `dst` is returned as an i16 type, the |
+ | | | converted values are packed such that the |
+ | | | value converted from the first element of `a` |
+ | | | is stored in the upper 8 bits of `dst` and the |
+ | | | value converted from the second element of `a` |
+ | | | is stored in the lower 8 bits of `dst`. |
+ | | | If `dst` is returned as a vector type, each |
+ | | | converted value from `a` is stored as an i8 |
+ | | | element in the vector. |
+ |-----------|-----------|--------------------------------------------------|
+ | f32, f32 | e4m3x2 | Both operands `a` and `b` must be provided and |
+ | | e5m2x2 | they must be F32 values. |
+ | | ue8m0x2 | If `dst` is returned as an i16 type, the |
+ | | | converted values are packed such that the |
+ | | | value converted from `a` is stored in the |
+ | | | upper 8 bits of `dst` and the value converted |
+ | | | from `b` is stored in the lower 8 bits of |
+ | | | `dst`. If `dst` is returned as a vector type, |
+ | | | each converted value is stored as an i8 |
+ | | | element in the vector. |
+ |-----------|-----------|--------------------------------------------------|
+ ```
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction for conversions to the signed f8 types (e4m3 and e5m2).
+ The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins
+ CVTFP8TypeAttr:$type,
+ AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$a,
+ Optional<F32>:$b,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat = "$type $a (`,` $b^)? attr-dict `:` type($a) (`,` type($b)^)? `->` type($dst)";
+
+ let extraClassDeclaration = [{
+ bool isFromF32Type();
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
+ bool isFromF32Type,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat,
+ bool hasRelu);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::CvtToF8x2Op::getIntrinsicID($type, op.isFromF32Type(), $rnd, $sat, $relu);
+ llvm::Value *packedI16;
+ if(op.isFromF32Type())
+ packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+ else
+ packedI16 = createIntrinsicCall(builder, intId, {$a});
+ if(op.getDst().getType().isInteger(16))
+ $dst = packedI16;
+ else
+ $dst = builder.CreateBitCast(packedI16,
+ llvm::FixedVectorType::get(
+ llvm::Type::getInt8Ty(builder.getContext()), 2));
+ }];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 18453aa7f6ea9..c30c45abbdd02 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Type.h"
@@ -133,6 +134,68 @@ LogicalResult CvtFloatToTF32Op::verify() {
return success();
}
+bool CvtToF8x2Op::isFromF32Type() { return getA().getType().isF32(); }
+
+LogicalResult CvtToF8x2Op::verify() {
+ bool isFromF32 = false;
+ bool isFromF16x2 = false;
+ bool isFromBF16x2 = false;
+
+ bool isRoundingModeRN = getRnd() == NVVM::FPRoundingMode::RN;
+ bool isRoundingModeRZ = getRnd() == NVVM::FPRoundingMode::RZ;
+ bool isRoundingModeRP = getRnd() == NVVM::FPRoundingMode::RP;
+
+ bool isSatFinite = getSat() == NVVM::SaturationMode::SATFINITE;
+
+ bool hasRelu = getRelu();
+
+ if (auto vecType = dyn_cast<VectorType>(getA().getType())) {
+ isFromF16x2 = vecType.getElementType().isF16();
+ isFromBF16x2 = vecType.getElementType().isBF16();
+ } else {
+ isFromF32 = true;
+ }
+
+ if (isFromF32) {
+ if (!(getODSOperands(1).size() > 0))
+ return emitOpError("expected two f32 inputs for converting from f32");
+ } else {
+ if (getODSOperands(1).size() > 0)
+ return emitOpError(
+ "expected only a single f32, vector<2xf16> or vector<2xbf16> input "
+ "for converting from f16x2 or bf16x2, got two inputs instead.");
+ }
+
+ switch (getType()) {
+ case NVVM::CVTFP8Type::E4M3:
+ case NVVM::CVTFP8Type::E5M2:
+ if (!(isFromF32 || isFromF16x2))
+ return emitOpError("expected f32 or f16x2 input for conversions to "
+ ".e4m3x2 or .e5m2x2 types");
+ if (!isRoundingModeRN)
+ return emitOpError("RN rounding mode required for conversions to .e4m3x2 "
+ "or .e5m2x2 types");
+ if (!isSatFinite)
+ return emitOpError("SATFINITE saturation mode required for conversions "
+ "to .e4m3x2 or .e5m2x2 types");
+ break;
+ case NVVM::CVTFP8Type::UE8M0:
+ if (!(isFromF32 || isFromBF16x2))
+ return emitOpError(
+ "expected f32 or bf16x2 input for conversions to .ue8m0x2 type");
+ if (!(isRoundingModeRP || isRoundingModeRZ))
+ return emitOpError(
+ "RP or RZ rounding mode required for conversions to .ue8m0x2 type");
+ if (hasRelu)
+ return emitOpError("relu not supported for conversions to .ue8m0x2 type");
+ break;
+ default:
+ return emitOpError("unsupported FP8 type");
+ }
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1304,6 +1367,40 @@ llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
}
}
+#define CVT_TO_UE8M0X2_IMPL(fromtype, rndm, has_sat) \
+ has_sat ? llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm##_satfinite \
+ : llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm
+
+#define GET_CVT_TO_UE8M0X2_ID(fromtype, rnd, has_sat) \
+ (rnd == NVVM::FPRoundingMode::RZ) \
+ ? CVT_TO_UE8M0X2_IMPL(fromtype, _rz, has_sat) \
+ : CVT_TO_UE8M0X2_IMPL(fromtype, _rp, has_sat)
+
+#define GET_CVT_TO_F8X2_ID(fromtype, totype, has_relu) \
+ has_relu ? llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn_relu \
+ : llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn
+
+llvm::Intrinsic::ID CvtToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type to,
+ bool isFromF32Type,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat,
+ bool hasRelu) {
+ bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+
+ switch (to) {
+ case NVVM::CVTFP8Type::E4M3:
+ return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e4m3x2, hasRelu)
+ : GET_CVT_TO_F8X2_ID(f16x2, e4m3x2, hasRelu);
+ case NVVM::CVTFP8Type::E5M2:
+ return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e5m2x2, hasRelu)
+ : GET_CVT_TO_F8X2_ID(f16x2, e5m2x2, hasRelu);
+ case NVVM::CVTFP8Type::UE8M0:
+ return isFromF32Type ? GET_CVT_TO_UE8M0X2_ID(ff, rnd, hasSatFinite)
+ : GET_CVT_TO_UE8M0X2_ID(bf16x2, rnd, hasSatFinite);
+ }
+ llvm_unreachable("Invalid CVTFP8Type for CvtToF8x2Op");
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
new file mode 100644
index 0000000000000..06573ce53676f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp8x2_packed
+llvm.func @convert_float_to_fp8x2_packed(%srcA : f32, %srcB : f32) -> !llvm.void {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+ %res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
+ %res4 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp8x2_vector
+llvm.func @convert_float_to_fp8x2_vector(%srcA : f32, %srcB : f32) -> !llvm.void {
+ // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+ %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+ %res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> vector<2xi8>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp8x2_with_relu
+llvm.func @convert_float_to_fp8x2_with_relu(%srcA : f32, %srcB : f32) -> !llvm.void {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp8x2
+llvm.func @convert_f16x2_to_fp8x2(%src : vector<2xf16>) -> !llvm.void {
+ // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+ %res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
+ llvm.return
+}
+
+
+// CHECK-LABEL: @convert_bf16x2_to_fp8x2
+llvm.func @convert_bf16x2_to_fp8x2(%src : vector<2xbf16>) -> !llvm.void {
+ // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+ %res1 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res2 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> vector<2xi8>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp8x2_with_relu
+llvm.func @convert_f16x2_to_fp8x2_with_relu(%src : vector<2xf16>) -> !llvm.void {
+ // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+ %res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..fc00ea6ee7003 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,91 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
llvm.return
}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e4m3(%a : f32, %b : f32) {
+ // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+ %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e5m2(%a : f32, %b : f32) {
+ // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+ %res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_ue8m0(%a : f32, %b : f32) {
+ // expected-error @below {{RP or RZ rounding mode required for conversions to .ue8m0x2 type}}
+ %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e4m3(%a : f32, %b : f32) {
+ // expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+ %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e5m2(%a : f32, %b : f32) {
+ // expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+ %res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
+ // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
+ %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : f32, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e4m3(%src : vector<2xbf16>) {
+ // expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
+ %res = nvvm.cvt.to.f8x2 <e4m3> %src : vector<2xbf16> -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e5m2(%src : vector<2xbf16>) {
+ // expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
+ %res = nvvm.cvt.to.f8x2 <e5m2> %src : vector<2xbf16> -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_ue8m0(%src : vector<2xf16>) {
+ // expected-error @below {{expected f32 or bf16x2 input for conversions to .ue8m0x2 type}}
+ %res = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_two_inputs_with_fromfp16x2(%src : vector<2xf16>, %b : f32) {
+ // expected-error @below {{expected only a single f32, vector<2xf16> or vector<2xbf16> input for converting from f16x2 or bf16x2, got two inputs instead.}}
+ %res = nvvm.cvt.to.f8x2 <e4m3> %src, %b : vector<2xf16>, f32 -> i16
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_missing_second_input(%a : f32) {
+ // expected-error @below {{expected two f32 inputs for converting from f32}}
+ %res = nvvm.cvt.to.f8x2 <e4m3> %a {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : f32 -> i16
+ llvm.return
+}
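For quick reference, here is a condensed usage sketch restating the source/destination combinations exercised by the tests in the diff above; the function and value names are illustrative and not part of the patch:

```mlir
llvm.func @cvt_to_f8x2_overview(%a : f32, %b : f32,
                                %h : vector<2xf16>, %bf : vector<2xbf16>) {
  // Two f32 operands -> packed i16 (value converted from %a in the upper
  // 8 bits, value converted from %b in the lower 8 bits).
  %r0 = nvvm.cvt.to.f8x2 <e4m3> %a, %b {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
  // Single f16x2 operand -> vector of two i8s holding the converted e5m2 values.
  %r1 = nvvm.cvt.to.f8x2 <e5m2> %h {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
  // Single bf16x2 operand -> ue8m0x2; only RZ/RP rounding is accepted and relu is not allowed.
  %r2 = nvvm.cvt.to.f8x2 <ue8m0> %bf {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
  llvm.return
}
```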
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 4b022aa to d7aebe2, then to d1464bb, then to 5a694bf.
The latest changes LGTM.
Would be good to fix the `hasVerifier` location in the Op, though.
Force-pushed from aec7cf4 to f883bb1.
This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2` Ops to the NVVM dialect for the conversions to the `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2` for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
Force-pushed from f883bb1 to c57ad76.
[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types (llvm#137781)

This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2` Ops to the NVVM dialect for the conversions to the `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2` for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
Creating a type for each op won't really scale. I think we should use existing types.
Yeah, that makes sense. I did it this way here since the other existing type enums were prefixed with the corresponding Op name. But we should probably unify all of them and rename them with an `NVVMFP` prefix instead, perhaps.
Yes Guray, I brought this up on the first fp6 cvt Op review itself (more from a re-use perspective, though).
With a unified enum (let's say, for all the FP types), we may need to update/tighten the verifiers
of many Ops to error out on unsupported types. Please let us know your thoughts on this.
Do we really need an enum embedded in the NVVM dialect? I'm asking whether we can just reuse the existing MLIR builtin types. At this point, I assume we have all the exotic types. What do you think?
I actually tried using the builtin types (`f8e4m3fn`, `f8e5m2`, and `f8e8m0fnu`) for these Ops, but ran into issues during lowering to LLVM IR when bitcasting the vector to a packed i16 for the intrinsics: it looks like vectors of these types cannot be constructed / are not supported, and an assertion fails in `mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp:821: isCompatibleVectorType()`.
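To make the trade-off being discussed concrete, below is a minimal sketch contrasting the result forms the op accepts in this revision (taken from the tests above) with the builtin-f8 variant that was tried; the commented-out `vector<2xf8E4M3FN>` line is hypothetical and only illustrates the rejected direction:

```mlir
llvm.func @cvt_to_f8x2_result_forms(%a : f32, %b : f32) {
  // Result forms used by this patch: a packed i16, or a vector of two i8s
  // (bitcast from the packed i16 during translation to LLVM IR).
  %r0 = nvvm.cvt.to.f8x2 <e4m3> %a, %b {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
  %r1 = nvvm.cvt.to.f8x2 <e4m3> %a, %b {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
  // Hypothetical alternative discussed above (not supported): returning a vector
  // of a builtin f8 element type fails the LLVM-compatibility check in translation.
  // %r2 = nvvm.cvt.to.f8x2 <e4m3> %a, %b {...} : f32, f32 -> vector<2xf8E4M3FN>
  llvm.return
}
```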
This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2` Ops to the NVVM dialect for the conversions to the `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2` for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt