@@ -6601,7 +6601,7 @@ struct NoNanSelfSubSimplify
66016601 PatternRewriter &rewriter) const {
66026602 if (op.getLhs() == op.getRhs()) {
66036603 if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
6604- op.getLhs().getType(), op)) {
6604+ op.getLhs().getType(), op, rewriter )) {
66056605 rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
66066606 op, rewriter.getZeroAttr(op.getType()));
66076607 return success();
@@ -6811,21 +6811,24 @@ struct NoNanDivSimplify final
68116811
68126812 LogicalResult matchAndRewriteImpl(stablehlo::DivOp op,
68136813 PatternRewriter &rewriter) const {
6814- if (!canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op))
6815- return failure();
6816-
68176814 // 0 / x -> 0
68186815 if (matchPattern(op.getLhs(), m_AnyZeroFloat()) ||
68196816 matchPattern(op.getLhs(), m_Zero())) {
6820- rewriter.replaceOp(op, op.getLhs());
6821- return success();
6817+ if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op,
6818+ rewriter)) {
6819+ rewriter.replaceOp(op, op.getLhs());
6820+ return success();
6821+ }
68226822 }
68236823
68246824 // x / x -> 1
68256825 if (op.getLhs() == op.getRhs()) {
6826- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6827- op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 1)));
6828- return success();
6826+ if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op,
6827+ rewriter)) {
6828+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6829+ op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 1)));
6830+ return success();
6831+ }
68296832 }
68306833
68316834 return failure();
@@ -6929,17 +6932,17 @@ struct NoNanZeroBasePowSimplify final
69296932
69306933 LogicalResult matchAndRewriteImpl(stablehlo::PowOp op,
69316934 PatternRewriter &rewriter) const {
6932- if (!canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op)) {
6933- return failure();
6934- }
6935-
69366935 if (matchPattern(op.getLhs(), m_Zero()) ||
69376936 matchPattern(op.getLhs(), m_AnyZeroFloat())) {
69386937
69396938 DenseElementsAttr attr;
69406939 if (matchPattern(op.getRhs(), m_Constant(&attr)))
69416940 return failure(); // let constant propagation handle this
69426941
6942+ if (!canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op,
6943+ rewriter))
6944+ return failure();
6945+
69436946 // 0 ^ x => x == 0 ? 1 : (x > 0 ? 0 : Inf)
69446947 auto zero = rewriter.create<stablehlo::ConstantOp>(
69456948 op.getLoc(), rewriter.getZeroAttr(op.getType()));
@@ -8537,14 +8540,17 @@ struct NoNanCompareSimplify
85378540 LogicalResult matchAndRewriteImpl(stablehlo::CompareOp op,
85388541 PatternRewriter &rewriter) const {
85398542 if (op.getLhs() == op.getRhs()) {
8540- if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
8541- op.getLhs(). getType(), op)) {
8542- if ( op.getComparisonDirection() == stablehlo::ComparisonDirection::EQ ) {
8543+ if (op.getComparisonDirection() == stablehlo::ComparisonDirection::EQ) {
8544+ if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
8545+ op.getLhs().getType(), op, rewriter) ) {
85438546 rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
85448547 op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 1)));
85458548 return success();
85468549 }
8547- if (op.getComparisonDirection() == stablehlo::ComparisonDirection::NE) {
8550+ }
8551+ if (op.getComparisonDirection() == stablehlo::ComparisonDirection::NE) {
8552+ if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
8553+ op.getLhs().getType(), op, rewriter)) {
85488554 rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
85498555 op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 0)));
85508556 return success();
@@ -16412,27 +16418,29 @@ struct NoNanMulSimplify final
1641216418
1641316419 LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
1641416420 PatternRewriter &rewriter) const {
16415- if (!canApplyNoNanPattern(allowOnFloatingPointMath,
16416- op.getResult().getType(),
16417- op.getOperand(0).getType(), op)) {
16418- return failure();
16419- }
16420-
1642116421 // 0 * x -> 0
1642216422 if (matchPattern(op.getLhs(), m_AnyZeroFloat()) ||
1642316423 matchPattern(op.getLhs(), m_Zero()) ||
1642416424 matchPattern(op.getLhs(), m_AnyZeroComplex())) {
16425- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16426- op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16427- return success();
16425+ if (canApplyNoNanPattern(allowOnFloatingPointMath,
16426+ op.getResult().getType(),
16427+ op.getOperand(0).getType(), op, rewriter)) {
16428+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16429+ op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16430+ return success();
16431+ }
1642816432 }
1642916433 // x * 0 -> 0
1643016434 if (matchPattern(op.getRhs(), m_AnyZeroFloat()) ||
1643116435 matchPattern(op.getRhs(), m_Zero()) ||
1643216436 matchPattern(op.getRhs(), m_AnyZeroComplex())) {
16433- rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16434- op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16435- return success();
16437+ if (canApplyNoNanPattern(allowOnFloatingPointMath,
16438+ op.getResult().getType(),
16439+ op.getOperand(0).getType(), op, rewriter)) {
16440+ rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16441+ op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16442+ return success();
16443+ }
1643616444 }
1643716445
1643816446 return failure();
@@ -16456,40 +16464,46 @@ struct NoNanAddSubSimplify final
1645616464 // Check if LHS is defined by an AddOp
1645716465 if (auto lhsAddOp = lhs.getDefiningOp<stablehlo::AddOp>()) {
1645816466 auto addOutTy = lhsAddOp.getResult().getType();
16459- if (!canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16460- op))
16461- return failure();
1646216467
1646316468 // Case: c = a + b; d = c - b -> d = a
1646416469 if (lhsAddOp.getRhs() == rhs) {
16465- rewriter.replaceOp(op, lhsAddOp.getLhs());
16466- return success();
16470+ if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16471+ op, rewriter)) {
16472+ rewriter.replaceOp(op, lhsAddOp.getLhs());
16473+ return success();
16474+ }
1646716475 }
1646816476
1646916477 // Case: c = a + b; d = c - a -> d = b
1647016478 if (lhsAddOp.getLhs() == rhs) {
16471- rewriter.replaceOp(op, lhsAddOp.getRhs());
16472- return success();
16479+ if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16480+ op, rewriter)) {
16481+ rewriter.replaceOp(op, lhsAddOp.getRhs());
16482+ return success();
16483+ }
1647316484 }
1647416485 }
1647516486
1647616487 // Check if RHS is defined by an AddOp
1647716488 if (auto rhsAddOp = rhs.getDefiningOp<stablehlo::AddOp>()) {
1647816489 auto addOutTy = rhsAddOp.getResult().getType();
16479- if (!canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16480- op))
16481- return failure();
1648216490
1648316491 // Case: c = a + b; d = b - c -> d = -a
1648416492 if (rhsAddOp.getLhs() == lhs) {
16485- rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getRhs());
16486- return success();
16493+ if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16494+ op, rewriter)) {
16495+ rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getRhs());
16496+ return success();
16497+ }
1648716498 }
1648816499
1648916500 // Case: c = a + b; d = a - c -> d = -b
1649016501 if (rhsAddOp.getRhs() == lhs) {
16491- rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getLhs());
16492- return success();
16502+ if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16503+ op, rewriter)) {
16504+ rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getLhs());
16505+ return success();
16506+ }
1649316507 }
1649416508 }
1649516509
@@ -17185,7 +17199,7 @@ struct AbsPositiveSimplify
1718517199 if (isa<ComplexType>(operand.getType().getElementType()))
1718617200 return failure();
1718717201
17188- if (guaranteedNonNegativeResult(operand.getDefiningOp())) {
17202+ if (guaranteedNonNegativeResult(operand.getDefiningOp(), rewriter )) {
1718917203 rewriter.replaceOp(op, op.getOperand());
1719017204 return success();
1719117205 }
0 commit comments