Skip to content

Commit a3dee48

Browse files
committed
Address review comments
1 parent 25a8fdd commit a3dee48

File tree

3 files changed

+37
-45
lines changed

3 files changed

+37
-45
lines changed

mlir/include/mlir/IR/Types.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ class Type {
109109
// Convenience predicates. This is only for floating point types,
110110
// derived types should use isa/dyn_cast.
111111
bool isIndex() const;
112-
bool isF8E8M0FNU() const;
113112
bool isBF16() const;
114113
bool isF16() const;
115114
bool isTF32() const;

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -330,21 +330,16 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
330330
LogicalResult matchAndRewrite(arith::ExtFOp op,
331331
PatternRewriter &rewriter) const final {
332332
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
333-
auto operand = op.getOperand();
333+
Value operand = op.getOperand();
334334
Type operandTy = operand.getType();
335335
Type resultTy = op.getType();
336336
Type operandETy = getElementTypeOrSelf(operandTy);
337337
Type resultETy = getElementTypeOrSelf(resultTy);
338338

339-
if (!operandETy.isF8E8M0FNU()) {
339+
if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
340340
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
341341
}
342342

343-
if (!resultETy.isBF16() && !resultETy.isF16() && !resultETy.isF32()) {
344-
return rewriter.notifyMatchFailure(
345-
op, "not a ext of F8M0FNU on a larger 16-bit or 32-bit width float.");
346-
}
347-
348343
Type i8Ty = b.getI8Type();
349344
Type i32Ty = b.getI32Type();
350345
Type f32Ty = b.getF32Type();
@@ -368,10 +363,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
368363
// select for NaNs
369364
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
370365
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
371-
if (resultETy.isBF16()) {
372-
result = b.create<arith::TruncFOp>(resultTy, result);
373-
} else if (resultETy.isF16()) {
366+
if (resultETy.getIntOrFloatBitWidth() < 32) {
374367
result = b.create<arith::TruncFOp>(resultTy, result);
368+
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
369+
result = b.create<arith::ExtFOp>(resultTy, result);
375370
}
376371
rewriter.replaceOp(op, result);
377372
return success();
@@ -388,18 +383,14 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
388383
LogicalResult matchAndRewrite(arith::TruncFOp op,
389384
PatternRewriter &rewriter) const final {
390385
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
391-
auto operand = op.getOperand();
386+
Value operand = op.getOperand();
392387
Type operandTy = operand.getType();
393388
Type operandETy = getElementTypeOrSelf(operandTy);
394389
Type resultTy = op.getType();
395390
Type resultETy = getElementTypeOrSelf(resultTy);
396-
if (!resultETy.isF8E8M0FNU()) {
391+
if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
397392
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
398393
}
399-
if (!operandETy.isBF16() && !operandETy.isF16() && !operandETy.isF32()) {
400-
return rewriter.notifyMatchFailure(
401-
op, "not a truncf of 16-bit or 32-bit float to f8E8M0FNU.");
402-
}
403394

404395
if (op.getRoundingmodeAttr()) {
405396
return rewriter.notifyMatchFailure(
@@ -414,8 +405,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
414405
i32Ty = shapedTy.clone(i32Ty);
415406
f32Ty = shapedTy.clone(f32Ty);
416407
}
417-
if (!operandETy.isF32()) {
408+
if (operandETy.getIntOrFloatBitWidth() < 32) {
418409
operand = b.create<arith::ExtFOp>(f32Ty, operand);
410+
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
411+
operand = b.create<arith::TruncFOp>(f32Ty, operand);
419412
}
420413
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
421414
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -453,36 +446,37 @@ struct ArithExpandOpsPass
453446
arith::MinNumFOp
454447
>();
455448

456-
if(includeBf16) {
449+
if (includeBf16) {
457450
arith::populateExpandBFloat16Patterns(patterns);
458451
}
459-
if(includeF8E8M0) {
452+
if (includeF8E8M0) {
460453
arith::populateExpandF8E8M0Patterns(patterns);
461454
}
462-
if (includeBf16 || includeF8E8M0) {
463-
target.addDynamicallyLegalOp<arith::ExtFOp>(
464-
[=](arith::ExtFOp op) {
465-
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
466-
Type outETy = getElementTypeOrSelf(op.getType());
467-
if(includeBf16 && includeF8E8M0)
468-
return !(inETy.isBF16() && outETy.isF32()) && !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
469-
if(includeBf16)
470-
return !(inETy.isBF16() && outETy.isF32());
471-
return !(inETy.isF8E8M0FNU() && (outETy.isF32() || outETy.isBF16() || outETy.isF16()));
472-
});
473-
474-
target.addDynamicallyLegalOp<arith::TruncFOp>(
475-
[=](arith::TruncFOp op) {
476-
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
477-
Type outETy = getElementTypeOrSelf(op.getType());
478-
if(includeBf16 && includeF8E8M0)
479-
return !(inETy.isF32() && outETy.isBF16()) && !(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
480-
if(includeBf16)
481-
return !(inETy.isF32() && outETy.isBF16());
482-
return
483-
!(outETy.isF8E8M0FNU() && (inETy.isF32() || inETy.isF16() || inETy.isBF16()));
484-
});
485-
}
455+
456+
target.addDynamicallyLegalOp<arith::ExtFOp>(
457+
[=](arith::ExtFOp op) {
458+
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
459+
Type outETy = getElementTypeOrSelf(op.getType());
460+
bool legalTypes = true;
461+
if(includeBf16)
462+
legalTypes &= !(inETy.isBF16() && outETy.isF32());
463+
if(includeF8E8M0)
464+
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
465+
return legalTypes;
466+
});
467+
468+
target.addDynamicallyLegalOp<arith::TruncFOp>(
469+
[=](arith::TruncFOp op) {
470+
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
471+
Type outETy = getElementTypeOrSelf(op.getType());
472+
bool legalTypes = true;
473+
if(includeBf16)
474+
legalTypes &= !(inETy.isF32() && outETy.isBF16());
475+
if(includeF8E8M0)
476+
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
477+
return legalTypes;
478+
});
479+
486480
// clang-format on
487481
if (failed(applyPartialConversion(getOperation(), target,
488482
std::move(patterns))))

mlir/lib/IR/Types.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ Type AbstractType::replaceImmediateSubElements(Type type,
3333
//===----------------------------------------------------------------------===//
3434

3535
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
36-
bool Type::isF8E8M0FNU() const { return llvm::isa<Float8E8M0FNUType>(*this); }
3736
bool Type::isBF16() const { return llvm::isa<BFloat16Type>(*this); }
3837
bool Type::isF16() const { return llvm::isa<Float16Type>(*this); }
3938
bool Type::isTF32() const { return llvm::isa<FloatTF32Type>(*this); }

0 commit comments

Comments
 (0)