Commit a875a16

[mlir][AMDGPU] implement ScaledExtPackedFp8Op and PackedScaledTrunc2xFp8Op

1 parent 936bf29

5 files changed: +321 −1 lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 56 additions & 0 deletions
@@ -112,6 +112,33 @@ def AMDGPU_ExtPackedFp8Op :
   }];
 }
 
+def AMDGPU_ScaledExtPackedFp8Op :
+    AMDGPU_Op<"scaled_ext_packed_fp8", [Pure]>,
+    Arguments<(ins AnyTypeOf<[F8E5M2, F8E4M3FN,
+                     VectorOfLengthAndType<[1, 2, 3, 4], [F8E5M2, F8E4M3FN]>]>:$source,
+                   F32:$scale,
+                   ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$index)>,
+    Results<(outs AnyTypeOf<[F32, FixedVectorOfLengthAndType<[2], [F32]>]>:$res)> {
+  let summary = "Extend an fp8 value to a float or a vector of packed fp8 values to two floats";
+
+  let description = [{
+    Extend and scale one or two 8-bit floats in `source[index]` to a 32-bit
+    float or two floats and return them.
+
+    This rather unusual signature arises from the fact that AMD GPUs cannot
+    easily work with sub-32-bit quantities, so the compiler intrinsics for
+    extending 8-bit floats (which are, currently, the only way to work with
+    this operation) take packed vectors of two such floats.
+
+    If the passed-in vector has fewer than two elements, or the input is
+    scalar, the remaining values in the packed vector are filled with
+    undefined values as needed.
+  }];
+  let assemblyFormat = [{
+    attr-dict $source `[` $index `]` `,` $scale `:` type($source) `to` type($res)
+  }];
+}
+
 def AMDGPU_PackedTrunc2xFp8Op :
     AMDGPU_Op<"packed_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
     Arguments<(ins F32:$sourceA,
@@ -139,6 +166,35 @@ def AMDGPU_PackedTrunc2xFp8Op :
   let hasVerifier = 1;
 }
 
+def AMDGPU_PackedScaledTrunc2xFp8Op :
+    AMDGPU_Op<"packed_scaled_trunc_2xfp8", [Pure, AttrSizedOperandSegments]>,
+    Arguments<(ins F32:$sourceA,
+                   Optional<F32>:$sourceB,
+                   F32:$scale,
+                   ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$wordIndex,
+                   Optional<FixedVectorOfLengthAndType<[4], [F8E5M2, F8E4M3FN]>>:$existing)>,
+    Results<(outs FixedVectorOfLengthAndType<[4], [F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, F8E5M2]>:$res)> {
+  let summary = "Round two floats into a packed vector of 8-bit floats";
+  let description = [{
+    Scale and round the inputs `sourceA` and `sourceB` (which is undefined if
+    not specified) into the low or high word (bottom two or top two elements)
+    of the returned vector, keeping the other two elements of `existing`
+    unchanged if present (or undefined if it was not passed in).
+
+    The reason for this odd signature is that AMD GPUs cannot easily work with
+    sub-registers, and so the conversion intrinsics (which are currently the
+    only way to work with 8-bit float types) take packed vectors of four 8-bit
+    values.
+  }];
+  let assemblyFormat = [{
+    attr-dict $sourceA `,` ($sourceB^):(`undef`)?
+    `into` ($existing^):(`undef`)? `[` `word` $wordIndex `]`
+    `,` $scale
+    `:` type($sourceA) `to` type($res) (`into` type($existing)^)?
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_PackedStochRoundFp8Op :
     AMDGPU_Op<"packed_stoch_round_fp8", [Pure]>,
     Arguments<(ins F32:$source,
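
For orientation, this is how the two new ops read in MLIR assembly once the definitions above are in place; a minimal sketch assembled from the assemblyFormat strings and the tests added below (SSA value names and element types are illustrative):

    // Extend the fp8 pair selected by index 0 and scale it by %scale.
    %ext = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>
    // Scale %a and %b, round them to fp8, and write the high word of %old.
    %packed = amdgpu.packed_scaled_trunc_2xfp8 %a, %b into %old[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>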

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 130 additions & 1 deletion
@@ -1148,6 +1148,19 @@ struct ExtPackedFp8OpLowering final
       ConversionPatternRewriter &rewriter) const override;
 };
 
+struct ScaledExtPackedFp8OpLowering final
+    : public ConvertOpToLLVMPattern<ScaledExtPackedFp8Op> {
+  ScaledExtPackedFp8OpLowering(const LLVMTypeConverter &converter,
+                               Chipset chipset)
+      : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedFp8Op>(converter),
+        chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct PackedTrunc2xFp8OpLowering final
     : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
   PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1161,6 +1174,20 @@ struct PackedTrunc2xFp8OpLowering final
       ConversionPatternRewriter &rewriter) const override;
 };
 
+struct PackedScaledTrunc2xFp8OpLowering final
+    : public ConvertOpToLLVMPattern<PackedScaledTrunc2xFp8Op> {
+  PackedScaledTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
+                                   Chipset chipset)
+      : ConvertOpToLLVMPattern<amdgpu::PackedScaledTrunc2xFp8Op>(converter),
+        chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(PackedScaledTrunc2xFp8Op op,
+                  PackedScaledTrunc2xFp8OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 struct PackedStochRoundFp8OpLowering final
     : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
   PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
@@ -1229,6 +1256,67 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
   }
   return success();
 }
+// rocdl.cvt.scalef32.pk.f32.fp8 %source[false]: i32, %c4: f32 : vector<2xf32>
+// rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
+
+// amdgpu.scaled_ext_packed_fp8 %v[0]: f8E5M2, %scale: f32 : f8E5M2 to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<2xf8E5M2>, %scale: f32 : vector<2xf8E5M2> to vector<2xf32>
+// amdgpu.scaled_ext_packed_fp8 %v[0]: vector<4xf8E5M2>, %scale: f32 : vector<4xf8E5M2> to vector<2xf32>
+LogicalResult ScaledExtPackedFp8OpLowering::matchAndRewrite(
+    ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset != kGfx950)
+    return rewriter.notifyMatchFailure(
+        loc, "Scaled fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type v4i8 =
+      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
+  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
+  Type f32 = getTypeConverter()->convertType(op.getResult().getType());
+
+  Value source = adaptor.getSource();
+  Value scale = adaptor.getScale();
+  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
+  auto resultVecType = dyn_cast<VectorType>(op.getResult().getType());
+  Type sourceElemType = getElementTypeOrSelf(op.getSource());
+  // Extend to a v4i8
+  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
+    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+    if (!sourceVecType) {
+      longVec = rewriter.create<LLVM::InsertElementOp>(
+          loc, longVec, source, createI32Constant(rewriter, loc, 0));
+    } else {
+      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
+        Value idx = createI32Constant(rewriter, loc, i);
+        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+        longVec =
+            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+      }
+    }
+    source = longVec;
+  }
+  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+  if (resultVecType) {
+    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Fp8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    }
+  } else {
+    if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Bf8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    } else if (typeIsExpectedFp8ForChipset(chipset, sourceElemType)) {
+      rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32F32Fp8Op>(
+          op, f32, i32Source, scale, op.getIndex());
+    }
+  }
+  return success();
+}
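
In IR terms, the pattern performs rewrites like the following on gfx950; a sketch condensed from the conversion tests added below (value names are illustrative):

    // Before:
    %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<2xf8E4M3FN> to vector<2xf32>
    // After: the source is widened to vector<4xi8> with undef padding, bitcast
    // to i32, and handed to the scaled ROCDL conversion intrinsic.
    %i32src = llvm.bitcast %widened : vector<4xi8> to i32
    %ret = rocdl.cvt.scalef32.pk.f32.fp8 %i32src[false], %scale : vector<2xf32>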
 
 LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
     PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
@@ -1266,6 +1354,46 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
   return success();
 }
 
+// rocdl.cvt.scalef32.pk.fp8.f32 %sourceA: f32, %sourceB: f32, %c0: f32 ->
+// %old[false]: vector<2xi16> : vector<2xi16>
+LogicalResult PackedScaledTrunc2xFp8OpLowering::matchAndRewrite(
+    PackedScaledTrunc2xFp8Op op, PackedScaledTrunc2xFp8OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  if (chipset != kGfx950)
+    return rewriter.notifyMatchFailure(
+        loc, "Scaled fp8 conversion instructions are not available on target "
+             "architecture and their emulation is not implemented");
+  Type v2i16 = getTypeConverter()->convertType(
+      VectorType::get(2, rewriter.getI16Type()));
+
+  Type resultType = op.getResult().getType();
+  Type resultElemType = getElementTypeOrSelf(resultType);
+
+  Value sourceA = adaptor.getSourceA();
+  Value sourceB = adaptor.getSourceB();
+  Value scale = adaptor.getScale();
+  if (!sourceB)
+    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
+  Value existing = adaptor.getExisting();
+  if (existing)
+    existing = rewriter.create<LLVM::BitcastOp>(loc, v2i16, existing);
+  else
+    existing = rewriter.create<LLVM::UndefOp>(loc, v2i16);
+
+  Value result;
+  if (typeIsExpectedBf8ForChipset(chipset, resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkBf8F32Op>(
+        loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+  else if (typeIsExpectedFp8ForChipset(chipset, resultElemType))
+    result = rewriter.create<ROCDL::CvtScaleF32PkFp8F32Op>(
+        loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex());
+
+  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
+      op, getTypeConverter()->convertType(resultType), result);
+  return success();
+}
+
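The truncation path rewrites analogously; a sketch condensed from the tests added below (value names are illustrative):

    // Before:
    %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
    // After: the pair is packed into word 0 of a vector<2xi16>, then bitcast
    // back to the packed-fp8 result type.
    %packed = rocdl.cvt.scalef32.pk.fp8.f32 %v, %w, %scale -> %existing[false] : vector<2xi16>
    %cast = llvm.bitcast %packed : vector<2xi16> to vector<4xi8>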
 LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
     PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1675,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            ROCDL::RawPtrBufferAtomicCmpSwap>,
       AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
       MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
-      ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+      ExtPackedFp8OpLowering, ScaledExtPackedFp8OpLowering,
+      PackedTrunc2xFp8OpLowering, PackedScaledTrunc2xFp8OpLowering,
       PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
                                                             chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 6 additions & 0 deletions
@@ -54,6 +54,12 @@ LogicalResult PackedTrunc2xFp8Op::verify() {
   return success();
 }
 
+LogicalResult PackedScaledTrunc2xFp8Op::verify() {
+  if (getExisting() && getExisting().getType() != getResult().getType())
+    return emitOpError("existing values must have same type as result");
+  return success();
+}
+
 LogicalResult PackedStochRoundFp8Op::verify() {
   if (getExisting() && getExisting().getType() != getResult().getType())
     return emitOpError("existing values must have same type as result");
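
The new verifier mirrors the existing PackedTrunc2xFp8Op check: when an `existing` vector is supplied, its type must match the result type. A hypothetical example that would now be rejected:

    // error: 'amdgpu.packed_scaled_trunc_2xfp8' op existing values must have same type as result
    %ret = amdgpu.packed_scaled_trunc_2xfp8 %a, %b into %old[word 0], %scale : f32 to vector<4xf8E4M3FN> into vector<4xf8E5M2>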
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
+
+// CHECK-LABEL: func @scaled_ext_scalar
+// CHECK-SAME: ([[IN:%.+]]: f8E5M2, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : f8E5M2 to i8
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0_1:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[VEC:%.+]] = llvm.insertelement [[V]], [[UNDEF]]{{\[}}[[C0_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.bf8 [[CAST]][0], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_scalar(%v: f8E5M2, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : f8E5M2 to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_short_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][1], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_short_vec(%v: vector<2xf8E4M3FN>, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<2xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_full_vec
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.f32.fp8 [[CAST]][3], [[SCALE]] : f32
+// CHECK: return [[EXT]] : f32
+func.func @scaled_ext_full_vec(%v: vector<4xf8E4M3FN>, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[3], %scale : vector<4xf8E4M3FN> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_2xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<2xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<2xf8E4M3FN> to vector<2xi8>
+// CHECK-DAG: [[UNDEF:%.+]] = llvm.mlir.undef : vector<4xi8>
+// CHECK-DAG: [[C0:%.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: [[ELEM_0:%.+]] = llvm.extractelement [[V]]{{\[}}[[C0]] : i32] : vector<2xi8>
+// CHECK: [[VEC_0:%.+]] = llvm.insertelement [[ELEM_0]], [[UNDEF]]{{\[}}[[C0]] : i32] : vector<4xi8>
+// CHECK: [[C1_1:%.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: [[ELEM_1:%.+]] = llvm.extractelement [[V]]{{\[}}[[C1_1]] : i32] : vector<2xi8>
+// CHECK: [[VEC_1:%.+]] = llvm.insertelement [[ELEM_1]], [[VEC_0]]{{\[}}[[C1_1]] : i32] : vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[VEC_1]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][false], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]]
+func.func @scaled_ext_packed_2xfp8(%v: vector<2xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<2xf8E4M3FN> to vector<2xf32>
+  func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_4xfp8
+// CHECK-SAME: ([[IN:%.+]]: vector<4xf8E4M3FN>, [[SCALE:%.+]]: f32)
+// CHECK: [[V:%.+]] = builtin.unrealized_conversion_cast [[IN]] : vector<4xf8E4M3FN> to vector<4xi8>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[V]] : vector<4xi8> to i32
+// CHECK: [[EXT:%.+]] = rocdl.cvt.scalef32.pk.f32.fp8 [[CAST]][true], [[SCALE]] : vector<2xf32>
+// CHECK: return [[EXT]] : vector<2xf32>
+func.func @scaled_ext_packed_4xfp8(%v: vector<4xf8E4M3FN>, %scale: f32) -> vector<2xf32> {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[1], %scale : vector<4xf8E4M3FN> to vector<2xf32>
+  func.return %ret : vector<2xf32>
+}
+
+// CHECK-LABEL: func @packed_scaled_trunc
+// CHECK-SAME: ([[V:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[V2:%.+]] = llvm.mlir.undef : f32
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[V2]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_trunc(%v: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, undef into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING:%.+]] = llvm.mlir.undef : vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.fp8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING]][false] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E4M3FN>
+func.func @packed_scaled_truncx2(%v: f32, %w: f32, %scale: f32) -> vector<4xf8E4M3FN> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into undef[word 0], %scale : f32 to vector<4xf8E4M3FN>
+  func.return %ret : vector<4xf8E4M3FN>
+}
+
+// CHECK-LABEL: func @packed_scaled_truncx2_into
+// CHECK-SAME: ([[V:%.+]]: f32, [[W:%.+]]: f32, [[EXISTING:%.+]]: vector<4xf8E5M2>, [[SCALE:%.+]]: f32)
+// CHECK: [[EXISTING_BYTES:%.+]] = builtin.unrealized_conversion_cast [[EXISTING]] : vector<4xf8E5M2> to vector<4xi8>
+// CHECK: [[EXISTING_INT:%.+]] = llvm.bitcast [[EXISTING_BYTES]] : vector<4xi8> to vector<2xi16>
+// CHECK: [[PACKED:%.+]] = rocdl.cvt.scalef32.pk.bf8.f32 [[V]], [[W]], [[SCALE]] -> [[EXISTING_INT]][true] : vector<2xi16>
+// CHECK: [[CAST:%.+]] = llvm.bitcast [[PACKED]] : vector<2xi16> to vector<4xi8>
+// CHECK: builtin.unrealized_conversion_cast [[CAST]] : vector<4xi8> to vector<4xf8E5M2>
+func.func @packed_scaled_truncx2_into(%v: f32, %w: f32, %existing: vector<4xf8E5M2>, %scale: f32) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v, %w into %existing[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 21 additions & 0 deletions
@@ -18,13 +18,34 @@ func.func @ext_packed_fp8_v(%v: vector<4xf8E4M3FNUZ>) -> vector<2xf32> {
   func.return %ret : vector<2xf32>
 }
 
+// CHECK-LABEL: func @scaled_ext_packed_fp8_s
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to f32
+func.func @scaled_ext_packed_fp8_s(%v: vector<4xf8E5M2>, %scale: f32) -> f32 {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to f32
+  func.return %ret : f32
+}
+
+// CHECK-LABEL: func @scaled_ext_packed_fp8_v
+// CHECK: amdgpu.scaled_ext_packed_fp8 {{.*}} vector<4xf8E5M2> to vector<2xf32>
+func.func @scaled_ext_packed_fp8_v(%v: vector<4xf8E5M2>, %scale: f32) -> vector<2xf32> {
+  %ret = amdgpu.scaled_ext_packed_fp8 %v[0], %scale : vector<4xf8E5M2> to vector<2xf32>
+  func.return %ret : vector<2xf32>
+}
+
 // CHECK-LABEL: func @packed_trunc_2xfp8
 // CHECK: amdgpu.packed_trunc_2xfp8
 func.func @packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2FNUZ>, %stoch: i32) -> vector<4xf8E5M2FNUZ> {
   %ret = amdgpu.packed_trunc_2xfp8 %v1, %v2 into %others[word 1] : f32 to vector<4xf8E5M2FNUZ> into vector<4xf8E5M2FNUZ>
   func.return %ret : vector<4xf8E5M2FNUZ>
 }
 
+// CHECK-LABEL: func @scaled_packed_trunc_2xfp8
+// CHECK: amdgpu.packed_scaled_trunc_2xfp8
+func.func @scaled_packed_trunc_2xfp8(%v1: f32, %v2: f32, %others: vector<4xf8E5M2>, %scale: f32) -> vector<4xf8E5M2> {
+  %ret = amdgpu.packed_scaled_trunc_2xfp8 %v1, %v2 into %others[word 1], %scale : f32 to vector<4xf8E5M2> into vector<4xf8E5M2>
+  func.return %ret : vector<4xf8E5M2>
+}
+
 // CHECK-LABEL: func @packed_stoch_round_fp8
 // CHECK: amdgpu.packed_stoch_round_fp8
 func.func @packed_stoch_round_fp8(%v1: f32, %stoch: i32, %others: vector<4xf8E5M2FNUZ>) -> vector<4xf8E5M2FNUZ> {
