Skip to content

Commit 4b022aa

Browse files
committed
[MLIR][NVVM] Add support for f8x2 conversion
This patch adds the `cvt.to.f8x2` NVVM dialect Op for conversion into f6x2 types. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent fa1fe11 commit 4b022aa

File tree

4 files changed

+360
-0
lines changed

4 files changed

+360
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,6 +1120,110 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11201120
}];
11211121
}
11221122

1123+
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1124+
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1125+
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1126+
1127+
def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
1128+
[CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
1129+
let genSpecializedAttr = 0;
1130+
let cppNamespace = "::mlir::NVVM";
1131+
}
1132+
def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
1133+
let assemblyFormat = "`<` $value `>`";
1134+
}
1135+
1136+
def NVVM_CvtToF8x2Op : NVVM_Op<"cvt.to.f8x2"> {
1137+
let summary = "Convert a pair of f32 or fp16 inputs to f8x2";
1138+
let description = [{
1139+
This Op converts each of the given float input types to the specified f8
1140+
type.
1141+
The result `dst` is either represented as an i16 type or a vector
1142+
of two f8 types.
1143+
The following table describes the supported conversions and their formats:
1144+
```
1145+
|-----------|-----------|--------------------------------------------------|
1146+
| Src Type | Dst Type | Description |
1147+
|-----------|-----------|--------------------------------------------------|
1148+
| f16x2 | e4m3x2 | Only operand `a` must be provided and it must |
1149+
| | e5m2x2 | be a vector of two F16s. |
1150+
| | | If `dst` is returned as an i16 type, the |
1151+
| | | converted values are packed such that the |
1152+
| | | value converted from the first element of `a` |
1153+
| | | is stored in the upper 8 bits of `dst` and the |
1154+
| | | value converted from the second element of `a` |
1155+
| | | is stored in the lower 8 bits of `dst`. |
1156+
| | | If `dst` is returned as a vector type, each |
1157+
| | | converted value from `a` is stored as an i8 |
1158+
| | | element in the vector. |
1159+
|-----------|-----------|--------------------------------------------------|
1160+
| bf16x2 | ue8m0x2 | Only operand `a` must be provided and it must |
1161+
| | | be a vector of two BF16s. |
1162+
| | | If `dst` is returned as an i16 type, the |
1163+
| | | converted values are packed such that the |
1164+
| | | value converted from the first element of `a` |
1165+
| | | is stored in the upper 8 bits of `dst` and the |
1166+
| | | value converted from the second element of `a` |
1167+
| | | is stored in the lower 8 bits of `dst`. |
1168+
| | | If `dst` is returned as a vector type, each |
1169+
| | | converted value from `a` is stored as an i8 |
1170+
| | | element in the vector. |
1171+
|-----------|-----------|--------------------------------------------------|
1172+
| f32, f32 | e4m3x2 | Both operands `a` and `b` must be provided and |
1173+
| | e5m2x2 | they must be F32 values. |
1174+
| | ue8m0x2 | If `dst` is returned as an i16 type, the |
1175+
| | | converted values are packed such that the |
1176+
| | | value converted from `a` is stored in the |
1177+
| | | upper 8 bits of `dst` and the value converted |
1178+
| | | from `b` is stored in the lower 8 bits of |
1179+
| | | `dst`. If `dst` is returned as a vector type, |
1180+
| | | each converted value is stored as an i8 |
1181+
| | | element in the vector. |
1182+
|-----------|-----------|--------------------------------------------------|
1183+
```
1184+
The `relu` attribute, when set, lowers to the '.relu' variant of
1185+
the cvt instruction for conversions to the signed f8 types (e4m3 and e5m2).
1186+
The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
1187+
1188+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1189+
}];
1190+
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
1191+
let arguments = (ins
1192+
CVTFP8TypeAttr:$type,
1193+
AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$a,
1194+
Optional<F32>:$b,
1195+
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
1196+
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
1197+
DefaultValuedAttr<BoolAttr, "false">:$relu);
1198+
let assemblyFormat = "$type $a (`,` $b^)? attr-dict `:` type($a) (`,` type($b)^)? `->` type($dst)";
1199+
1200+
let extraClassDeclaration = [{
1201+
bool isFromF32Type();
1202+
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
1203+
bool isFromF32Type,
1204+
NVVM::FPRoundingMode rnd,
1205+
NVVM::SaturationMode sat,
1206+
bool hasRelu);
1207+
}];
1208+
1209+
string llvmBuilder = [{
1210+
auto intId = NVVM::CvtToF8x2Op::getIntrinsicID($type, op.isFromF32Type(), $rnd, $sat, $relu);
1211+
llvm::Value *packedI16;
1212+
if(op.isFromF32Type())
1213+
packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
1214+
else
1215+
packedI16 = createIntrinsicCall(builder, intId, {$a});
1216+
if(op.getDst().getType().isInteger(16))
1217+
$dst = packedI16;
1218+
else
1219+
$dst = builder.CreateBitCast(packedI16,
1220+
llvm::FixedVectorType::get(
1221+
llvm::Type::getInt8Ty(builder.getContext()), 2));
1222+
}];
1223+
1224+
let hasVerifier = 1;
1225+
}
1226+
11231227
//===----------------------------------------------------------------------===//
11241228
// NVVM MMA Ops
11251229
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/ADT/TypeSwitch.h"
3333
#include "llvm/AsmParser/Parser.h"
3434
#include "llvm/IR/Attributes.h"
35+
#include "llvm/IR/IRBuilder.h"
3536
#include "llvm/IR/Function.h"
3637
#include "llvm/IR/IntrinsicsNVPTX.h"
3738
#include "llvm/IR/Type.h"
@@ -133,6 +134,68 @@ LogicalResult CvtFloatToTF32Op::verify() {
133134
return success();
134135
}
135136

137+
bool CvtToF8x2Op::isFromF32Type() { return getA().getType().isF32(); }
138+
139+
LogicalResult CvtToF8x2Op::verify() {
140+
bool isFromF32 = false;
141+
bool isFromF16x2 = false;
142+
bool isFromBF16x2 = false;
143+
144+
bool isRoundingModeRN = getRnd() == NVVM::FPRoundingMode::RN;
145+
bool isRoundingModeRZ = getRnd() == NVVM::FPRoundingMode::RZ;
146+
bool isRoundingModeRP = getRnd() == NVVM::FPRoundingMode::RP;
147+
148+
bool isSatFinite = getSat() == NVVM::SaturationMode::SATFINITE;
149+
150+
bool hasRelu = getRelu();
151+
152+
if (auto vecType = dyn_cast<VectorType>(getA().getType())) {
153+
isFromF16x2 = vecType.getElementType().isF16();
154+
isFromBF16x2 = vecType.getElementType().isBF16();
155+
} else {
156+
isFromF32 = true;
157+
}
158+
159+
if (isFromF32) {
160+
if (!(getODSOperands(1).size() > 0))
161+
return emitOpError("expected two f32 inputs for converting from f32");
162+
} else {
163+
if (getODSOperands(1).size() > 0)
164+
return emitOpError(
165+
"expected only a single f32, vector<2xf16> or vector<2xbf16> input "
166+
"for converting from f16x2 or bf16x2, got two inputs instead.");
167+
}
168+
169+
switch (getType()) {
170+
case NVVM::CVTFP8Type::E4M3:
171+
case NVVM::CVTFP8Type::E5M2:
172+
if (!(isFromF32 || isFromF16x2))
173+
return emitOpError("expected f32 or f16x2 input for conversions to "
174+
".e4m3x2 or .e5m2x2 types");
175+
if (!isRoundingModeRN)
176+
return emitOpError("RN rounding mode required for conversions to .e4m3x2 "
177+
"or .e5m2x2 types");
178+
if (!isSatFinite)
179+
return emitOpError("SATFINITE saturation mode required for conversions "
180+
"to .e4m3x2 or .e5m2x2 types");
181+
break;
182+
case NVVM::CVTFP8Type::UE8M0:
183+
if (!(isFromF32 || isFromBF16x2))
184+
return emitOpError(
185+
"expected f32 or bf16x2 input for conversions to .ue8m0x2 type");
186+
if (!(isRoundingModeRP || isRoundingModeRZ))
187+
return emitOpError(
188+
"RP or RZ rounding mode required for conversions to .ue8m0x2 type");
189+
if (hasRelu)
190+
return emitOpError("relu not supported for conversions to .ue8m0x2 type");
191+
break;
192+
default:
193+
return emitOpError("unsupported FP8 type");
194+
}
195+
196+
return success();
197+
}
198+
136199
LogicalResult BulkStoreOp::verify() {
137200
if (getInitVal() != 0)
138201
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1304,6 +1367,40 @@ llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
13041367
}
13051368
}
13061369

1370+
#define CVT_TO_UE8M0X2_IMPL(fromtype, rndm, has_sat) \
1371+
has_sat ? llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm##_satfinite \
1372+
: llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm
1373+
1374+
#define GET_CVT_TO_UE8M0X2_ID(fromtype, rnd, has_sat) \
1375+
(rnd == NVVM::FPRoundingMode::RZ) \
1376+
? CVT_TO_UE8M0X2_IMPL(fromtype, _rz, has_sat) \
1377+
: CVT_TO_UE8M0X2_IMPL(fromtype, _rp, has_sat)
1378+
1379+
#define GET_CVT_TO_F8X2_ID(fromtype, totype, has_relu) \
1380+
has_relu ? llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn_relu \
1381+
: llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn
1382+
1383+
llvm::Intrinsic::ID CvtToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type to,
1384+
bool isFromF32Type,
1385+
NVVM::FPRoundingMode rnd,
1386+
NVVM::SaturationMode sat,
1387+
bool hasRelu) {
1388+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1389+
1390+
switch (to) {
1391+
case NVVM::CVTFP8Type::E4M3:
1392+
return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e4m3x2, hasRelu)
1393+
: GET_CVT_TO_F8X2_ID(f16x2, e4m3x2, hasRelu);
1394+
case NVVM::CVTFP8Type::E5M2:
1395+
return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e5m2x2, hasRelu)
1396+
: GET_CVT_TO_F8X2_ID(f16x2, e5m2x2, hasRelu);
1397+
case NVVM::CVTFP8Type::UE8M0:
1398+
return isFromF32Type ? GET_CVT_TO_UE8M0X2_ID(ff, rnd, hasSatFinite)
1399+
: GET_CVT_TO_UE8M0X2_ID(bf16x2, rnd, hasSatFinite);
1400+
}
1401+
llvm_unreachable("Invalid CVTFP8Type for CvtToF8x2Op");
1402+
}
1403+
13071404
llvm::Intrinsic::ID
13081405
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
13091406
LLVM::ModuleTranslation &mt,
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @convert_float_to_fp8x2_packed
4+
llvm.func @convert_float_to_fp8x2_packed(%srcA : f32, %srcB : f32) -> !llvm.void {
5+
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
6+
%res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
7+
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
8+
%res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
9+
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
10+
%res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
11+
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
12+
%res4 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
13+
llvm.return
14+
}
15+
16+
// CHECK-LABEL: @convert_float_to_fp8x2_vector
17+
llvm.func @convert_float_to_fp8x2_vector(%srcA : f32, %srcB : f32) -> !llvm.void {
18+
// CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
19+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
20+
%res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
21+
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
22+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
23+
%res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
24+
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
25+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
26+
%res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> vector<2xi8>
27+
llvm.return
28+
}
29+
30+
// CHECK-LABEL: @convert_float_to_fp8x2_with_relu
31+
llvm.func @convert_float_to_fp8x2_with_relu(%srcA : f32, %srcB : f32) -> !llvm.void {
32+
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
33+
%res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
34+
// CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
35+
%res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
36+
llvm.return
37+
}
38+
39+
// CHECK-LABEL: @convert_f16x2_to_fp8x2
40+
llvm.func @convert_f16x2_to_fp8x2(%src : vector<2xf16>) -> !llvm.void {
41+
// CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
42+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
43+
%res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
44+
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
45+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
46+
%res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
47+
llvm.return
48+
}
49+
50+
51+
// CHECK-LABEL: @convert_bf16x2_to_fp8x2
52+
llvm.func @convert_bf16x2_to_fp8x2(%src : vector<2xbf16>) -> !llvm.void {
53+
// CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
54+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
55+
%res1 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
56+
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
57+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
58+
%res2 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> vector<2xi8>
59+
llvm.return
60+
}
61+
62+
// CHECK-LABEL: @convert_f16x2_to_fp8x2_with_relu
63+
llvm.func @convert_f16x2_to_fp8x2_with_relu(%src : vector<2xf16>) -> !llvm.void {
64+
// CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
65+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
66+
%res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
67+
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
68+
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
69+
%res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
70+
llvm.return
71+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,91 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
176176
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
177177
llvm.return
178178
}
179+
180+
// -----
181+
182+
llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e4m3(%a : f32, %b : f32) {
183+
// expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
184+
%res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
185+
llvm.return
186+
}
187+
188+
// -----
189+
190+
llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e5m2(%a : f32, %b : f32) {
191+
// expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
192+
%res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
193+
llvm.return
194+
}
195+
196+
// -----
197+
198+
llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_ue8m0(%a : f32, %b : f32) {
199+
// expected-error @below {{RP or RZ rounding mode required for conversions to .ue8m0x2 type}}
200+
%res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
201+
llvm.return
202+
}
203+
204+
// -----
205+
206+
llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e4m3(%a : f32, %b : f32) {
207+
// expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
208+
%res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
209+
llvm.return
210+
}
211+
212+
// -----
213+
214+
llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e5m2(%a : f32, %b : f32) {
215+
// expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
216+
%res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
217+
llvm.return
218+
}
219+
220+
// -----
221+
222+
llvm.func @nvvm_cvt_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
223+
// expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
224+
%res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : f32, f32 -> i16
225+
llvm.return
226+
}
227+
228+
// -----
229+
230+
llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e4m3(%src : vector<2xbf16>) {
231+
// expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
232+
%res = nvvm.cvt.to.f8x2 <e4m3> %src : vector<2xbf16> -> i16
233+
llvm.return
234+
}
235+
236+
// -----
237+
238+
llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e5m2(%src : vector<2xbf16>) {
239+
// expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
240+
%res = nvvm.cvt.to.f8x2 <e5m2> %src : vector<2xbf16> -> i16
241+
llvm.return
242+
}
243+
244+
// -----
245+
246+
llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_ue8m0(%src : vector<2xf16>) {
247+
// expected-error @below {{expected f32 or bf16x2 input for conversions to .ue8m0x2 type}}
248+
%res = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> -> i16
249+
llvm.return
250+
}
251+
252+
// -----
253+
254+
llvm.func @nvvm_cvt_to_f8x2_two_inputs_with_fromfp16x2(%src : vector<2xf16>, %b : f32) {
255+
// expected-error @below {{expected only a single f32, vector<2xf16> or vector<2xbf16> input for converting from f16x2 or bf16x2, got two inputs instead.}}
256+
%res = nvvm.cvt.to.f8x2 <e4m3> %src, %b : vector<2xf16>, f32 -> i16
257+
llvm.return
258+
}
259+
260+
// -----
261+
262+
llvm.func @nvvm_cvt_to_f8x2_missing_second_input(%a : f32) {
263+
// expected-error @below {{expected two f32 inputs for converting from f32}}
264+
%res = nvvm.cvt.to.f8x2 <e4m3> %a {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : f32 -> i16
265+
llvm.return
266+
}

0 commit comments

Comments
 (0)