@@ -330,21 +330,16 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
330
330
LogicalResult matchAndRewrite (arith::ExtFOp op,
331
331
PatternRewriter &rewriter) const final {
332
332
ImplicitLocOpBuilder b (op.getLoc (), rewriter);
333
- auto operand = op.getOperand ();
333
+ Value operand = op.getOperand ();
334
334
Type operandTy = operand.getType ();
335
335
Type resultTy = op.getType ();
336
336
Type operandETy = getElementTypeOrSelf (operandTy);
337
337
Type resultETy = getElementTypeOrSelf (resultTy);
338
338
339
- if (!operandETy. isF8E8M0FNU ( )) {
339
+ if (!llvm::isa<Float8E8M0FNUType>(operandETy )) {
340
340
return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
341
341
}
342
342
343
- if (!resultETy.isBF16 () && !resultETy.isF16 () && !resultETy.isF32 ()) {
344
- return rewriter.notifyMatchFailure (
345
- op, " not a ext of F8M0FNU on a larger 16-bit or 32-bit width float." );
346
- }
347
-
348
343
Type i8Ty = b.getI8Type ();
349
344
Type i32Ty = b.getI32Type ();
350
345
Type f32Ty = b.getF32Type ();
@@ -368,10 +363,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
368
363
// select for NaNs
369
364
f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
370
365
Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
371
- if (resultETy.isBF16 ()) {
372
- result = b.create <arith::TruncFOp>(resultTy, result);
373
- } else if (resultETy.isF16 ()) {
366
+ if (resultETy.getIntOrFloatBitWidth () < 32 ) {
374
367
result = b.create <arith::TruncFOp>(resultTy, result);
368
+ } else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
369
+ result = b.create <arith::ExtFOp>(resultTy, result);
375
370
}
376
371
rewriter.replaceOp (op, result);
377
372
return success ();
@@ -388,18 +383,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
388
383
LogicalResult matchAndRewrite (arith::TruncFOp op,
389
384
PatternRewriter &rewriter) const final {
390
385
ImplicitLocOpBuilder b (op.getLoc (), rewriter);
391
- auto operand = op.getOperand ();
386
+ Value operand = op.getOperand ();
392
387
Type operandTy = operand.getType ();
393
388
Type operandETy = getElementTypeOrSelf (operandTy);
394
389
Type resultTy = op.getType ();
395
390
Type resultETy = getElementTypeOrSelf (resultTy);
396
- if (!resultETy. isF8E8M0FNU ( )) {
391
+ if (!llvm::isa<Float8E8M0FNUType>(resultETy )) {
397
392
return rewriter.notifyMatchFailure (op, " not a truncf to f8E8M0FNU" );
398
393
}
399
- if (!operandETy.isBF16 () && !operandETy.isF16 () && !operandETy.isF32 ()) {
400
- return rewriter.notifyMatchFailure (
401
- op, " not a truncf of 16-bit or 32-bit float to f8E8M0FNU." );
402
- }
403
394
404
395
if (op.getRoundingmodeAttr ()) {
405
396
return rewriter.notifyMatchFailure (
@@ -414,8 +405,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
414
405
i32Ty = shapedTy.clone (i32Ty);
415
406
f32Ty = shapedTy.clone (f32Ty);
416
407
}
417
- if (! operandETy.isF32 () ) {
408
+ if (operandETy.getIntOrFloatBitWidth () < 32 ) {
418
409
operand = b.create <arith::ExtFOp>(f32Ty, operand);
410
+ } else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
411
+ operand = b.create <arith::TruncFOp>(f32Ty, operand);
419
412
}
420
413
Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
421
414
Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
@@ -453,36 +446,37 @@ struct ArithExpandOpsPass
453
446
arith::MinNumFOp
454
447
>();
455
448
456
- if (includeBf16) {
449
+ if (includeBf16) {
457
450
arith::populateExpandBFloat16Patterns (patterns);
458
451
}
459
- if (includeF8E8M0) {
452
+ if (includeF8E8M0) {
460
453
arith::populateExpandF8E8M0Patterns (patterns);
461
454
}
462
- if (includeBf16 || includeF8E8M0) {
463
- target.addDynamicallyLegalOp <arith::ExtFOp>(
464
- [=](arith::ExtFOp op) {
465
- Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
466
- Type outETy = getElementTypeOrSelf (op.getType ());
467
- if (includeBf16 && includeF8E8M0)
468
- return !(inETy.isBF16 () && outETy.isF32 ()) && !(inETy.isF8E8M0FNU () && (outETy.isF32 () || outETy.isBF16 () || outETy.isF16 ()));
469
- if (includeBf16)
470
- return !(inETy.isBF16 () && outETy.isF32 ());
471
- return !(inETy.isF8E8M0FNU () && (outETy.isF32 () || outETy.isBF16 () || outETy.isF16 ()));
472
- });
473
-
474
- target.addDynamicallyLegalOp <arith::TruncFOp>(
475
- [=](arith::TruncFOp op) {
476
- Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
477
- Type outETy = getElementTypeOrSelf (op.getType ());
478
- if (includeBf16 && includeF8E8M0)
479
- return !(inETy.isF32 () && outETy.isBF16 ()) && !(outETy.isF8E8M0FNU () && (inETy.isF32 () || inETy.isF16 () || inETy.isBF16 ()));
480
- if (includeBf16)
481
- return !(inETy.isF32 () && outETy.isBF16 ());
482
- return
483
- !(outETy.isF8E8M0FNU () && (inETy.isF32 () || inETy.isF16 () || inETy.isBF16 ()));
484
- });
485
- }
455
+
456
+ target.addDynamicallyLegalOp <arith::ExtFOp>(
457
+ [=](arith::ExtFOp op) {
458
+ Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
459
+ Type outETy = getElementTypeOrSelf (op.getType ());
460
+ bool legalTypes = true ;
461
+ if (includeBf16)
462
+ legalTypes &= !(inETy.isBF16 () && outETy.isF32 ());
463
+ if (includeF8E8M0)
464
+ legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
465
+ return legalTypes;
466
+ });
467
+
468
+ target.addDynamicallyLegalOp <arith::TruncFOp>(
469
+ [=](arith::TruncFOp op) {
470
+ Type inETy = getElementTypeOrSelf (op.getOperand ().getType ());
471
+ Type outETy = getElementTypeOrSelf (op.getType ());
472
+ bool legalTypes = true ;
473
+ if (includeBf16)
474
+ legalTypes &= !(inETy.isF32 () && outETy.isBF16 ());
475
+ if (includeF8E8M0)
476
+ legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
477
+ return legalTypes;
478
+ });
479
+
486
480
// clang-format on
487
481
if (failed (applyPartialConversion (getOperation (), target,
488
482
std::move (patterns))))
0 commit comments