@@ -1148,6 +1148,19 @@ struct ExtPackedFp8OpLowering final
1148
1148
ConversionPatternRewriter &rewriter) const override ;
1149
1149
};
1150
1150
1151
+ struct ScaledExtPackedFp8OpLowering final
1152
+ : public ConvertOpToLLVMPattern<ScaledExtPackedFp8Op> {
1153
+ ScaledExtPackedFp8OpLowering (const LLVMTypeConverter &converter,
1154
+ Chipset chipset)
1155
+ : ConvertOpToLLVMPattern<amdgpu::ScaledExtPackedFp8Op>(converter),
1156
+ chipset (chipset) {}
1157
+ Chipset chipset;
1158
+
1159
+ LogicalResult
1160
+ matchAndRewrite (ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
1161
+ ConversionPatternRewriter &rewriter) const override ;
1162
+ };
1163
+
1151
1164
struct PackedTrunc2xFp8OpLowering final
1152
1165
: public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
1153
1166
PackedTrunc2xFp8OpLowering (const LLVMTypeConverter &converter,
@@ -1161,6 +1174,20 @@ struct PackedTrunc2xFp8OpLowering final
1161
1174
ConversionPatternRewriter &rewriter) const override ;
1162
1175
};
1163
1176
1177
+ struct PackedScaledTrunc2xFp8OpLowering final
1178
+ : public ConvertOpToLLVMPattern<PackedScaledTrunc2xFp8Op> {
1179
+ PackedScaledTrunc2xFp8OpLowering (const LLVMTypeConverter &converter,
1180
+ Chipset chipset)
1181
+ : ConvertOpToLLVMPattern<amdgpu::PackedScaledTrunc2xFp8Op>(converter),
1182
+ chipset (chipset) {}
1183
+ Chipset chipset;
1184
+
1185
+ LogicalResult
1186
+ matchAndRewrite (PackedScaledTrunc2xFp8Op op,
1187
+ PackedScaledTrunc2xFp8OpAdaptor adaptor,
1188
+ ConversionPatternRewriter &rewriter) const override ;
1189
+ };
1190
+
1164
1191
struct PackedStochRoundFp8OpLowering final
1165
1192
: public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
1166
1193
PackedStochRoundFp8OpLowering (const LLVMTypeConverter &converter,
@@ -1229,6 +1256,67 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
1229
1256
}
1230
1257
return success ();
1231
1258
}
1259
+ // rocdl.cvt.scalef32.pk.f32.fp8 %source[false]: i32, %c4: f32 : vector<2xf32>
1260
+ // rocdl.cvt.scalef32.f32.fp8 %source[0], %c4 : f32
1261
+
1262
+ // amdgpu.scaled_ext_packed_fp8 %v[0]: f8E5M2, %scale: f32 : f8E5M2 to
1263
+ // vector<2xf32> amdgpu.scaled_ext_packed_fp8 %v[0]: vector<2xf8E5M2>, %scale:
1264
+ // f32 : vector<2xf8E5M2> to vector<2xf32> amdgpu.scaled_ext_packed_fp8 %v[0]:
1265
+ // vector<4xf8E5M2>, %scale: f32 : vector<4xf8E5M2> to vector<2xf32>
1266
+ LogicalResult ScaledExtPackedFp8OpLowering::matchAndRewrite (
1267
+ ScaledExtPackedFp8Op op, ScaledExtPackedFp8OpAdaptor adaptor,
1268
+ ConversionPatternRewriter &rewriter) const {
1269
+ Location loc = op.getLoc ();
1270
+ if (chipset != kGfx950 )
1271
+ return rewriter.notifyMatchFailure (
1272
+ loc, " Scaled fp8 conversion instructions are not available on target "
1273
+ " architecture and their emulation is not implemented" );
1274
+ Type v4i8 =
1275
+ getTypeConverter ()->convertType (VectorType::get (4 , rewriter.getI8Type ()));
1276
+ Type i32 = getTypeConverter ()->convertType (rewriter.getI32Type ());
1277
+ Type f32 = getTypeConverter ()->convertType (op.getResult ().getType ());
1278
+
1279
+ Value source = adaptor.getSource ();
1280
+ Value scale = adaptor.getScale ();
1281
+ auto sourceVecType = dyn_cast<VectorType>(op.getSource ().getType ());
1282
+ auto resultVecType = dyn_cast<VectorType>(op.getResult ().getType ());
1283
+ Type sourceElemType = getElementTypeOrSelf (op.getSource ());
1284
+ // Extend to a v4i8
1285
+ if (!sourceVecType || sourceVecType.getNumElements () < 4 ) {
1286
+ Value longVec = rewriter.create <LLVM::UndefOp>(loc, v4i8);
1287
+ if (!sourceVecType) {
1288
+ longVec = rewriter.create <LLVM::InsertElementOp>(
1289
+ loc, longVec, source, createI32Constant (rewriter, loc, 0 ));
1290
+ } else {
1291
+ for (int32_t i = 0 , e = sourceVecType.getNumElements (); i < e; ++i) {
1292
+ Value idx = createI32Constant (rewriter, loc, i);
1293
+ Value elem = rewriter.create <LLVM::ExtractElementOp>(loc, source, idx);
1294
+ longVec =
1295
+ rewriter.create <LLVM::InsertElementOp>(loc, longVec, elem, idx);
1296
+ }
1297
+ }
1298
+ source = longVec;
1299
+ }
1300
+ Value i32Source = rewriter.create <LLVM::BitcastOp>(loc, i32 , source);
1301
+ if (resultVecType) {
1302
+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
1303
+ rewriter.replaceOpWithNewOp <ROCDL::CvtScaleF32PkF32Bf8Op>(
1304
+ op, f32 , i32Source, scale, op.getIndex ());
1305
+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
1306
+ rewriter.replaceOpWithNewOp <ROCDL::CvtScaleF32PkF32Fp8Op>(
1307
+ op, f32 , i32Source, scale, op.getIndex ());
1308
+ }
1309
+ } else {
1310
+ if (typeIsExpectedBf8ForChipset (chipset, sourceElemType)) {
1311
+ rewriter.replaceOpWithNewOp <ROCDL::CvtScaleF32F32Bf8Op>(
1312
+ op, f32 , i32Source, scale, op.getIndex ());
1313
+ } else if (typeIsExpectedFp8ForChipset (chipset, sourceElemType)) {
1314
+ rewriter.replaceOpWithNewOp <ROCDL::CvtScaleF32F32Fp8Op>(
1315
+ op, f32 , i32Source, scale, op.getIndex ());
1316
+ }
1317
+ }
1318
+ return success ();
1319
+ }
1232
1320
1233
1321
LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite (
1234
1322
PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
@@ -1266,6 +1354,46 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
1266
1354
return success ();
1267
1355
}
1268
1356
1357
+ // rocdl.cvt.scalef32.pk.fp8.f32 %sourceA: f32, %sourceB: f32, %c0: f32 ->
1358
+ // %old[false]: vector<2xi16> : vector<2xi16>
1359
+ LogicalResult PackedScaledTrunc2xFp8OpLowering::matchAndRewrite (
1360
+ PackedScaledTrunc2xFp8Op op, PackedScaledTrunc2xFp8OpAdaptor adaptor,
1361
+ ConversionPatternRewriter &rewriter) const {
1362
+ Location loc = op.getLoc ();
1363
+ if (chipset != kGfx950 )
1364
+ return rewriter.notifyMatchFailure (
1365
+ loc, " Scaled fp8 conversion instructions are not available on target "
1366
+ " architecture and their emulation is not implemented" );
1367
+ Type v2i16 = getTypeConverter ()->convertType (
1368
+ VectorType::get (2 , rewriter.getI16Type ()));
1369
+
1370
+ Type resultType = op.getResult ().getType ();
1371
+ Type resultElemType = getElementTypeOrSelf (resultType);
1372
+
1373
+ Value sourceA = adaptor.getSourceA ();
1374
+ Value sourceB = adaptor.getSourceB ();
1375
+ Value scale = adaptor.getScale ();
1376
+ if (!sourceB)
1377
+ sourceB = rewriter.create <LLVM::UndefOp>(loc, sourceA.getType ());
1378
+ Value existing = adaptor.getExisting ();
1379
+ if (existing)
1380
+ existing = rewriter.create <LLVM::BitcastOp>(loc, v2i16, existing);
1381
+ else
1382
+ existing = rewriter.create <LLVM::UndefOp>(loc, v2i16);
1383
+
1384
+ Value result;
1385
+ if (typeIsExpectedBf8ForChipset (chipset, resultElemType))
1386
+ result = rewriter.create <ROCDL::CvtScaleF32PkBf8F32Op>(
1387
+ loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex ());
1388
+ else if (typeIsExpectedFp8ForChipset (chipset, resultElemType))
1389
+ result = rewriter.create <ROCDL::CvtScaleF32PkFp8F32Op>(
1390
+ loc, v2i16, existing, sourceA, sourceB, scale, op.getWordIndex ());
1391
+
1392
+ result = rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(
1393
+ op, getTypeConverter ()->convertType (resultType), result);
1394
+ return success ();
1395
+ }
1396
+
1269
1397
LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite (
1270
1398
PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
1271
1399
ConversionPatternRewriter &rewriter) const {
@@ -1547,7 +1675,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
1547
1675
ROCDL::RawPtrBufferAtomicCmpSwap>,
1548
1676
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1549
1677
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1550
- ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1678
+ ExtPackedFp8OpLowering, ScaledExtPackedFp8OpLowering,
1679
+ PackedTrunc2xFp8OpLowering, PackedScaledTrunc2xFp8OpLowering,
1551
1680
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1552
1681
chipset);
1553
1682
patterns.add <AMDGPUSwizzleBitModeLowering>(converter);
0 commit comments