9
9
#include " mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
10
10
11
11
#include " mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12
+ #include " mlir/Dialect/AMDGPU/Utils/Chipset.h"
12
13
#include " mlir/Dialect/Arith/IR/Arith.h"
13
14
#include " mlir/Dialect/Arith/Utils/Utils.h"
15
+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
16
+ #include " mlir/Dialect/LLVMIR/ROCDLDialect.h"
14
17
#include " mlir/Dialect/Vector/IR/VectorOps.h"
15
18
#include " mlir/IR/BuiltinTypes.h"
16
19
#include " mlir/IR/PatternMatch.h"
@@ -24,6 +27,7 @@ namespace mlir {
24
27
} // namespace mlir
25
28
26
29
using namespace mlir ;
30
+ using namespace mlir ::amdgpu;
27
31
28
32
namespace {
29
33
struct ArithToAMDGPUConversionPass final
@@ -43,12 +47,25 @@ struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
43
47
44
48
/// Rewrites `arith.truncf` ops that produce 8-bit float element types into
/// AMDGPU packed-conversion intrinsics.
///
/// `saturateFP8` selects saturating truncation: out-of-range inputs are
/// clamped to the FP8 maximum instead of converting to NaN/infinity.
/// `chipset` records the target GPU so the rewrite can pick the correct
/// conversion instructions for that architecture.
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult match(arith::TruncFOp op) const override;
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};
59
+
60
+ struct TruncfToFloat16RewritePattern final
61
+ : public OpRewritePattern<arith::TruncFOp> {
62
+
63
+ using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
64
+
65
+ LogicalResult match (arith::TruncFOp op) const override ;
66
+ void rewrite (arith::TruncFOp op, PatternRewriter &rewriter) const override ;
67
+ };
68
+
52
69
} // end namespace
53
70
54
71
static Value castF32To (Type elementType, Value f32 , Location loc,
@@ -272,17 +289,105 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
272
289
rewriter.replaceOp (op, result);
273
290
}
274
291
292
+ LogicalResult TruncfToFloat16RewritePattern::match (arith::TruncFOp op) const {
293
+ Type outType = op.getOut ().getType ();
294
+ Type inputType = getElementTypeOrSelf (op.getIn ());
295
+ if (auto outVecType = dyn_cast<VectorType>(outType)) {
296
+ if (outVecType.isScalable ())
297
+ return failure ();
298
+ outType = outVecType.getElementType ();
299
+ }
300
+ return success (outType.isF16 () && inputType.isF32 ());
301
+ }
302
+
303
+ void TruncfToFloat16RewritePattern::rewrite (arith::TruncFOp op,
304
+ PatternRewriter &rewriter) const {
305
+ Location loc = op.getLoc ();
306
+ Value in = op.getIn ();
307
+ Type outElemType = getElementTypeOrSelf (op.getOut ().getType ());
308
+ VectorType truncResType = VectorType::get (2 , outElemType);
309
+ auto inVectorTy = dyn_cast<VectorType>(in.getType ());
310
+
311
+ // Handle the case where input type is not a vector type
312
+ if (!inVectorTy) {
313
+ auto sourceB = rewriter.create <LLVM::PoisonOp>(loc, rewriter.getF32Type ());
314
+ Value asF16s =
315
+ rewriter.create <ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
316
+ Value result = rewriter.create <vector::ExtractElementOp>(
317
+ loc, asF16s, rewriter.createOrFold <arith::ConstantIndexOp>(loc, 0 ));
318
+ return rewriter.replaceOp (op, result);
319
+ }
320
+ VectorType outType = cast<VectorType>(op.getOut ().getType ());
321
+ int64_t numElements = outType.getNumElements ();
322
+ Value zero = rewriter.createOrFold <arith::ConstantOp>(
323
+ loc, outElemType, rewriter.getFloatAttr (outElemType, 0.0 ));
324
+ Value result = rewriter.createOrFold <vector::SplatOp>(loc, outType, zero);
325
+
326
+ if (inVectorTy.getRank () > 1 ) {
327
+ inVectorTy = VectorType::get (SmallVector<int64_t >{numElements},
328
+ inVectorTy.getElementType ());
329
+ in = rewriter.create <vector::ShapeCastOp>(loc, inVectorTy, in);
330
+ }
331
+
332
+ // Handle the vector case. We also handle the (uncommon) case where the vector
333
+ // length is odd
334
+ for (int64_t i = 0 ; i < numElements; i += 2 ) {
335
+ int64_t elemsThisOp = std::min (numElements, i + 2 ) - i;
336
+ Value thisResult = nullptr ;
337
+ Value elemA = rewriter.create <vector::ExtractElementOp>(
338
+ loc, in, rewriter.create <arith::ConstantIndexOp>(loc, i));
339
+ Value elemB = rewriter.create <LLVM::PoisonOp>(loc, rewriter.getF32Type ());
340
+
341
+ if (elemsThisOp == 2 ) {
342
+ elemB = rewriter.create <vector::ExtractElementOp>(
343
+ loc, in, rewriter.createOrFold <arith::ConstantIndexOp>(loc, i + 1 ));
344
+ }
345
+
346
+ thisResult =
347
+ rewriter.create <ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
348
+ // Place back the truncated result into the possibly larger vector. If we
349
+ // are operating on a size 2 vector, these operations should be folded away
350
+ thisResult = rewriter.create <vector::ExtractStridedSliceOp>(
351
+ loc, thisResult, 0 , elemsThisOp, 1 );
352
+ result = rewriter.create <vector::InsertStridedSliceOp>(loc, thisResult,
353
+ result, i, 1 );
354
+ }
355
+
356
+ if (inVectorTy.getRank () != outType.getRank ()) {
357
+ result = rewriter.create <vector::ShapeCastOp>(loc, outType, result);
358
+ }
359
+
360
+ rewriter.replaceOp (op, result);
361
+ }
362
+
275
363
void mlir::arith::populateArithToAMDGPUConversionPatterns (
276
- RewritePatternSet &patterns, bool saturateFP8TruncF) {
277
- patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext ());
278
- patterns.add <TruncFToFloat8RewritePattern>(patterns.getContext (),
279
- saturateFP8TruncF);
364
+ RewritePatternSet &patterns, bool convertFP8Arithmetic,
365
+ bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
366
+
367
+ if (convertFP8Arithmetic) {
368
+ patterns.add <ExtFOnFloat8RewritePattern>(patterns.getContext ());
369
+ patterns.add <TruncFToFloat8RewritePattern>(patterns.getContext (),
370
+ saturateFP8Truncf, chipset);
371
+ }
372
+ if (allowPackedF16Rtz)
373
+ patterns.add <TruncfToFloat16RewritePattern>(patterns.getContext ());
280
374
}
281
375
282
376
void ArithToAMDGPUConversionPass::runOnOperation () {
283
377
Operation *op = getOperation ();
378
+ MLIRContext *ctx = &getContext ();
284
379
RewritePatternSet patterns (op->getContext ());
285
- arith::populateArithToAMDGPUConversionPatterns (patterns, saturateFP8Truncf);
380
+ FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse (chipset);
381
+ if (failed (maybeChipset)) {
382
+ emitError (UnknownLoc::get (ctx), " Invalid chipset name: " + chipset);
383
+ return signalPassFailure ();
384
+ }
385
+
386
+ bool convertFP8Arithmetic =
387
+ (*maybeChipset).majorVersion == 9 && (*maybeChipset).minorVersion >= 0x40 ;
388
+ arith::populateArithToAMDGPUConversionPatterns (
389
+ patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
390
+ *maybeChipset);
286
391
if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns))))
287
392
return signalPassFailure ();
288
393
}
0 commit comments