Skip to content

Commit aec7cf4

Browse files
committed
[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types
This change: - Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2` Ops to the NVVM dialect for the conversions to `.e4m3x2`, `.e5m2x2`, and `.ue8m0x2` types. - Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2` for consistency with the other conversion Ops. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
1 parent fa1fe11 commit aec7cf4

File tree

5 files changed

+455
-16
lines changed

5 files changed

+455
-16
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 149 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
10791079
let assemblyFormat = "`<` $value `>`";
10801080
}
10811081

1082-
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1082+
def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {
10831083
let summary = "Convert a pair of float inputs to f6x2";
10841084
let description = [{
10851085
This Op converts each of the given float inputs to the specified fp6 type.
@@ -1110,7 +1110,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11101110
}];
11111111

11121112
string llvmBuilder = [{
1113-
auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
1113+
auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu);
11141114
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
11151115
if(op.getDst().getType().isInteger(16))
11161116
$dst = packedI16;
@@ -1120,6 +1120,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11201120
}];
11211121
}
11221122

1123+
// Enum cases for the fp8 destination formats accepted by the
// `cvt.*.to.f8x2` Ops. Each case lists its symbol, its integer value and
// the keyword string associated with it.
def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1124+
def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1125+
def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1126+
1127+
// 32-bit integer enum grouping the three cases above; it is placed in the
// ::mlir::NVVM namespace and referenced from C++ as NVVM::CVTFP8Type.
def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
1128+
[CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
1129+
// NOTE(review): presumably disabled because CVTFP8TypeAttr below supplies
// the dialect attribute instead -- confirm against the ODS documentation.
let genSpecializedAttr = 0;
1130+
let cppNamespace = "::mlir::NVVM";
1131+
}
1132+
// Dialect attribute wrapping CVTFP8Type under the mnemonic `cvt_fp8_type`;
// its assembly form is the enum value enclosed in angle brackets.
def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
1133+
let assemblyFormat = "`<` $value `>`";
1134+
}
1135+
1136+
// Converts two f32 operands to a packed pair of fp8 values. The legal
// combinations of `type`, `rnd`, `sat` and `relu` are checked by the Op's
// verifier and are spelled out in the description below.
def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
  let summary = "Convert a pair of float inputs to f8x2";
  let description = [{
    This Op converts each of the given float inputs to the specified fp8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values are packed such
    that the value converted from `a` is stored in the upper 8 bits of `dst`
    and the value converted from `b` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `rnd` and `sat` attributes specify the rounding and saturation modes
    respectively.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    The verifier accepts only the following combinations:
    - `e4m3` and `e5m2`: the rounding mode must be RN and the saturation
      mode must be SATFINITE.
    - `ue8m0`: the rounding mode must be RZ or RP, and `relu` must not be
      set.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    F32:$a,
    F32:$b,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
                                              NVVM::FPRoundingMode rnd,
                                              NVVM::SaturationMode sat,
                                              bool hasRelu);
  }];

  // The intrinsic call is bound to an i16 value; it is bitcast to <2 x i8>
  // when the Op's result type is not i16.
  string llvmBuilder = [{
    auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
    if (op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];

  let hasVerifier = 1;
}
1182+
1183+
// Converts a two-element f16 vector to a packed pair of fp8 values. The
// verifier rejects the `ue8m0` destination type for this Op.
def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
  let summary = "Convert an f16x2 input to f8x2";
  let description = [{
    This Op converts the given f16 inputs in an f16x2 vector to the specified
    f8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values from `a`
    are packed such that the value converted from the first element of `a`
    is stored in the upper 8 bits of `dst` and the value converted from the
    second element of `a` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `relu` attribute, when set, lowers to the '.relu' variant of
    the cvt instruction.

    Only the `e4m3` and `e5m2` destination types are supported; the verifier
    rejects `ue8m0`.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    VectorOfLengthAndType<[2], [F16]>:$a,
    DefaultValuedAttr<BoolAttr, "false">:$relu);
  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
                                              bool hasRelu);
  }];

  // The intrinsic call is bound to an i16 value; it is bitcast to <2 x i8>
  // when the Op's result type is not i16.
  string llvmBuilder = [{
    auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
    if (op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];

  let hasVerifier = 1;
}
1225+
1226+
// Converts a two-element bf16 vector to a packed pair of fp8 values. The
// verifier accepts only the `ue8m0` destination type with RZ or RP rounding.
def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
  let summary = "Convert a bf16x2 input to f8x2";
  let description = [{
    This Op converts the given bf16 inputs in a bf16x2 vector to the specified
    f8 type.
    The result `dst` is represented as an i16 type or as a vector
    of two i8 types.
    If `dst` is returned as an i16 type, the converted values from `a`
    are packed such that the value converted from the first element of `a`
    is stored in the upper 8 bits of `dst` and the value converted from the
    second element of `a` is stored in the lower 8 bits of `dst`.
    If `dst` is returned as a vector type, each converted value is stored as an
    i8 element in the vector.
    The `rnd` and `sat` attributes specify the rounding and saturation modes
    respectively.

    Only the `ue8m0` destination type is supported, and the rounding mode
    must be RZ or RP; the verifier rejects every other combination.

    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
  }];
  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
  let arguments = (ins
    CVTFP8TypeAttr:$type,
    VectorOfLengthAndType<[2], [BF16]>:$a,
    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";

  let extraClassDeclaration = [{
    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
                                              NVVM::SaturationMode sat);
  }];

  // The intrinsic call is bound to an i16 value; it is bitcast to <2 x i8>
  // when the Op's result type is not i16.
  string llvmBuilder = [{
    auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
    if (op.getDst().getType().isInteger(16))
      $dst = packedI16;
    else
      $dst = builder.CreateBitCast(packedI16,
        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
  }];

  let hasVerifier = 1;
}
1269+
11231270
//===----------------------------------------------------------------------===//
11241271
// NVVM MMA Ops
11251272
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,61 @@ LogicalResult CvtFloatToTF32Op::verify() {
133133
return success();
134134
}
135135

136+
/// Checks that the rounding mode, saturation mode and relu flag form a
/// combination that is allowed for the requested fp8 destination type.
///
/// Returns success() when the combination is accepted; otherwise emits an
/// Op error describing the first violated constraint.
LogicalResult CvtF32x2ToF8x2Op::verify() {
  const auto rounding = getRnd();

  switch (getType()) {
  case CVTFP8Type::E4M3:
  case CVTFP8Type::E5M2: {
    // Both types share the same constraints: RN rounding with SATFINITE.
    if (rounding != NVVM::FPRoundingMode::RN)
      return emitOpError("Only RN rounding mode is supported for conversions "
                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
    if (getSat() != NVVM::SaturationMode::SATFINITE)
      return emitOpError("Only SATFINITE saturation mode is supported for "
                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
    break;
  }
  case CVTFP8Type::UE8M0: {
    // The saturation mode is left unconstrained for this type.
    const bool isRZOrRP = rounding == NVVM::FPRoundingMode::RZ ||
                          rounding == NVVM::FPRoundingMode::RP;
    if (!isRZOrRP)
      return emitOpError("Only RZ or RP rounding modes are supported for "
                         "conversions from f32x2 to .ue8m0x2 type");
    if (getRelu())
      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
    break;
  }
  }
  return success();
}
167+
168+
/// Accepts every destination type except UE8M0, for which an Op error is
/// emitted.
LogicalResult CvtF16x2ToF8x2Op::verify() {
  const bool isUE8M0 = getType() == CVTFP8Type::UE8M0;
  if (!isUE8M0)
    return success();

  return emitOpError("Only .e4m3 or .e5m2 types are supported for "
                     "conversions from f16x2 to f8x2.");
}
175+
176+
/// Accepts only the UE8M0 destination type combined with RZ or RP rounding;
/// any other combination produces an Op error. The destination type is
/// checked before the rounding mode.
LogicalResult CvtBF16x2ToF8x2Op::verify() {
  if (getType() != CVTFP8Type::UE8M0)
    return emitOpError(
        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");

  switch (getRnd()) {
  case NVVM::FPRoundingMode::RZ:
  case NVVM::FPRoundingMode::RP:
    return success();
  default:
    return emitOpError("Only RZ and RP rounding modes are supported for "
                       "conversions from bf16x2 to f8x2.");
  }
}
190+
136191
LogicalResult BulkStoreOp::verify() {
137192
if (getInitVal() != 0)
138193
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1290,17 +1345,81 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12901345
}
12911346
}
12921347

1293-
#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \
1348+
#define GET_FLOAT_TO_F6x2_ID(type, has_relu) \
12941349
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
12951350
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
12961351

1297-
llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1298-
bool hasRelu) {
1352+
llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
1353+
bool hasRelu) {
12991354
switch (type) {
13001355
case NVVM::CVTFP6Type::E2M3:
1301-
return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
1356+
return GET_FLOAT_TO_F6x2_ID(e2m3x2, hasRelu);
13021357
case NVVM::CVTFP6Type::E3M2:
1303-
return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
1358+
return GET_FLOAT_TO_F6x2_ID(e3m2x2, hasRelu);
1359+
}
1360+
}
1361+
1362+
#define GET_FLOAT_TO_F8X2_US_ID(rnd, has_satf) \
1363+
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
1364+
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
1365+
1366+
#define GET_FLOAT_TO_F8X2_S_ID(type, has_relu) \
1367+
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu \
1368+
: llvm::Intrinsic::nvvm_ff_to_##type##_rn
1369+
1370+
llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
1371+
NVVM::FPRoundingMode rnd,
1372+
NVVM::SaturationMode sat,
1373+
bool hasRelu) {
1374+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1375+
bool hasRoundingModeRZ = (rnd == NVVM::FPRoundingMode::RZ);
1376+
bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
1377+
1378+
switch (type) {
1379+
case NVVM::CVTFP8Type::E4M3:
1380+
return GET_FLOAT_TO_F8X2_S_ID(e4m3x2, hasRelu);
1381+
case NVVM::CVTFP8Type::E5M2:
1382+
return GET_FLOAT_TO_F8X2_S_ID(e5m2x2, hasRelu);
1383+
case NVVM::CVTFP8Type::UE8M0:
1384+
if (hasRoundingModeRZ)
1385+
return GET_FLOAT_TO_F8X2_US_ID(rz, hasSatFinite);
1386+
else if (hasRoundingModeRP)
1387+
return GET_FLOAT_TO_F8X2_US_ID(rp, hasSatFinite);
1388+
}
1389+
llvm_unreachable("Invalid conversion in CvtFloatToF8x2Op");
1390+
}
1391+
1392+
#define GET_F16x2_TO_F8X2_ID(type, has_relu) \
1393+
has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu \
1394+
: llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
1395+
1396+
llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
1397+
bool hasRelu) {
1398+
switch (type) {
1399+
case NVVM::CVTFP8Type::E4M3:
1400+
return GET_F16x2_TO_F8X2_ID(e4m3x2, hasRelu);
1401+
case NVVM::CVTFP8Type::E5M2:
1402+
return GET_F16x2_TO_F8X2_ID(e5m2x2, hasRelu);
1403+
default:
1404+
llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
1405+
}
1406+
}
1407+
1408+
#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
1409+
has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
1410+
: llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
1411+
1412+
llvm::Intrinsic::ID
1413+
CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
1414+
NVVM::SaturationMode sat) {
1415+
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
1416+
switch (rnd) {
1417+
case NVVM::FPRoundingMode::RZ:
1418+
return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
1419+
case NVVM::FPRoundingMode::RP:
1420+
return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
1421+
default:
1422+
llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
13041423
}
13051424
}
13061425

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
22

3-
// CHECK-LABEL: @convert_float_to_fp6x2_packed
4-
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
3+
// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
4+
llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
55
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6-
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
6+
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
77
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
8-
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
8+
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
99
llvm.return
1010
}
1111

12-
// CHECK-LABEL: @convert_float_to_fp6x2_vector
13-
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
12+
// CHECK-LABEL: @convert_f32x2_to_fp6x2_vector
13+
llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
1414
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1515
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
16-
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
16+
%res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
1717
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
1818
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
19-
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
19+
%res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
2020
llvm.return
2121
}
22-

0 commit comments

Comments
 (0)