Skip to content

Commit 996b3e5

Browse files
authored
feat: defer analysis passes + mark ir with analysis results (#1482)
* feat: defer analysis passes + mark ir with analysis results * fix: use modify op in place
1 parent cea9c6b commit 996b3e5

File tree

9 files changed

+234
-119
lines changed

9 files changed

+234
-119
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6601,7 +6601,7 @@ struct NoNanSelfSubSimplify
66016601
PatternRewriter &rewriter) const {
66026602
if (op.getLhs() == op.getRhs()) {
66036603
if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
6604-
op.getLhs().getType(), op)) {
6604+
op.getLhs().getType(), op, rewriter)) {
66056605
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
66066606
op, rewriter.getZeroAttr(op.getType()));
66076607
return success();
@@ -6811,21 +6811,24 @@ struct NoNanDivSimplify final
68116811

68126812
LogicalResult matchAndRewriteImpl(stablehlo::DivOp op,
68136813
PatternRewriter &rewriter) const {
6814-
if (!canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op))
6815-
return failure();
6816-
68176814
// 0 / x -> 0
68186815
if (matchPattern(op.getLhs(), m_AnyZeroFloat()) ||
68196816
matchPattern(op.getLhs(), m_Zero())) {
6820-
rewriter.replaceOp(op, op.getLhs());
6821-
return success();
6817+
if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op,
6818+
rewriter)) {
6819+
rewriter.replaceOp(op, op.getLhs());
6820+
return success();
6821+
}
68226822
}
68236823

68246824
// x / x -> 1
68256825
if (op.getLhs() == op.getRhs()) {
6826-
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6827-
op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 1)));
6828-
return success();
6826+
if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op,
6827+
rewriter)) {
6828+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
6829+
op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 1)));
6830+
return success();
6831+
}
68296832
}
68306833

68316834
return failure();
@@ -6929,17 +6932,17 @@ struct NoNanZeroBasePowSimplify final
69296932

69306933
LogicalResult matchAndRewriteImpl(stablehlo::PowOp op,
69316934
PatternRewriter &rewriter) const {
6932-
if (!canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op)) {
6933-
return failure();
6934-
}
6935-
69366935
if (matchPattern(op.getLhs(), m_Zero()) ||
69376936
matchPattern(op.getLhs(), m_AnyZeroFloat())) {
69386937

69396938
DenseElementsAttr attr;
69406939
if (matchPattern(op.getRhs(), m_Constant(&attr)))
69416940
return failure(); // let constant propagation handle this
69426941

6942+
if (!canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(), op,
6943+
rewriter))
6944+
return failure();
6945+
69436946
// 0 ^ x => x == 0 ? 1 : (x > 0 ? 0 : Inf)
69446947
auto zero = rewriter.create<stablehlo::ConstantOp>(
69456948
op.getLoc(), rewriter.getZeroAttr(op.getType()));
@@ -8537,14 +8540,17 @@ struct NoNanCompareSimplify
85378540
LogicalResult matchAndRewriteImpl(stablehlo::CompareOp op,
85388541
PatternRewriter &rewriter) const {
85398542
if (op.getLhs() == op.getRhs()) {
8540-
if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
8541-
op.getLhs().getType(), op)) {
8542-
if (op.getComparisonDirection() == stablehlo::ComparisonDirection::EQ) {
8543+
if (op.getComparisonDirection() == stablehlo::ComparisonDirection::EQ) {
8544+
if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
8545+
op.getLhs().getType(), op, rewriter)) {
85438546
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
85448547
op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 1)));
85458548
return success();
85468549
}
8547-
if (op.getComparisonDirection() == stablehlo::ComparisonDirection::NE) {
8550+
}
8551+
if (op.getComparisonDirection() == stablehlo::ComparisonDirection::NE) {
8552+
if (canApplyNoNanPattern(allowOnFloatingPointMath, op.getType(),
8553+
op.getLhs().getType(), op, rewriter)) {
85488554
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
85498555
op, op.getType(), cast<ElementsAttr>(makeAttr(op.getType(), 0)));
85508556
return success();
@@ -16412,27 +16418,29 @@ struct NoNanMulSimplify final
1641216418

1641316419
LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
1641416420
PatternRewriter &rewriter) const {
16415-
if (!canApplyNoNanPattern(allowOnFloatingPointMath,
16416-
op.getResult().getType(),
16417-
op.getOperand(0).getType(), op)) {
16418-
return failure();
16419-
}
16420-
1642116421
// 0 * x -> 0
1642216422
if (matchPattern(op.getLhs(), m_AnyZeroFloat()) ||
1642316423
matchPattern(op.getLhs(), m_Zero()) ||
1642416424
matchPattern(op.getLhs(), m_AnyZeroComplex())) {
16425-
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16426-
op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16427-
return success();
16425+
if (canApplyNoNanPattern(allowOnFloatingPointMath,
16426+
op.getResult().getType(),
16427+
op.getOperand(0).getType(), op, rewriter)) {
16428+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16429+
op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16430+
return success();
16431+
}
1642816432
}
1642916433
// x * 0 -> 0
1643016434
if (matchPattern(op.getRhs(), m_AnyZeroFloat()) ||
1643116435
matchPattern(op.getRhs(), m_Zero()) ||
1643216436
matchPattern(op.getRhs(), m_AnyZeroComplex())) {
16433-
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16434-
op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16435-
return success();
16437+
if (canApplyNoNanPattern(allowOnFloatingPointMath,
16438+
op.getResult().getType(),
16439+
op.getOperand(0).getType(), op, rewriter)) {
16440+
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
16441+
op, cast<ElementsAttr>(makeAttr(op.getType(), 0)));
16442+
return success();
16443+
}
1643616444
}
1643716445

1643816446
return failure();
@@ -16456,40 +16464,46 @@ struct NoNanAddSubSimplify final
1645616464
// Check if LHS is defined by an AddOp
1645716465
if (auto lhsAddOp = lhs.getDefiningOp<stablehlo::AddOp>()) {
1645816466
auto addOutTy = lhsAddOp.getResult().getType();
16459-
if (!canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16460-
op))
16461-
return failure();
1646216467

1646316468
// Case: c = a + b; d = c - b -> d = a
1646416469
if (lhsAddOp.getRhs() == rhs) {
16465-
rewriter.replaceOp(op, lhsAddOp.getLhs());
16466-
return success();
16470+
if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16471+
op, rewriter)) {
16472+
rewriter.replaceOp(op, lhsAddOp.getLhs());
16473+
return success();
16474+
}
1646716475
}
1646816476

1646916477
// Case: c = a + b; d = c - a -> d = b
1647016478
if (lhsAddOp.getLhs() == rhs) {
16471-
rewriter.replaceOp(op, lhsAddOp.getRhs());
16472-
return success();
16479+
if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16480+
op, rewriter)) {
16481+
rewriter.replaceOp(op, lhsAddOp.getRhs());
16482+
return success();
16483+
}
1647316484
}
1647416485
}
1647516486

1647616487
// Check if RHS is defined by an AddOp
1647716488
if (auto rhsAddOp = rhs.getDefiningOp<stablehlo::AddOp>()) {
1647816489
auto addOutTy = rhsAddOp.getResult().getType();
16479-
if (!canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16480-
op))
16481-
return failure();
1648216490

1648316491
// Case: c = a + b; d = b - c -> d = -a
1648416492
if (rhsAddOp.getLhs() == lhs) {
16485-
rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getRhs());
16486-
return success();
16493+
if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16494+
op, rewriter)) {
16495+
rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getRhs());
16496+
return success();
16497+
}
1648716498
}
1648816499

1648916500
// Case: c = a + b; d = a - c -> d = -b
1649016501
if (rhsAddOp.getRhs() == lhs) {
16491-
rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getLhs());
16492-
return success();
16502+
if (canApplyNoNanPattern(allowOnFloatingPointMath, addOutTy, subOutTy,
16503+
op, rewriter)) {
16504+
rewriter.replaceOpWithNewOp<stablehlo::NegOp>(op, rhsAddOp.getLhs());
16505+
return success();
16506+
}
1649316507
}
1649416508
}
1649516509

@@ -17185,7 +17199,7 @@ struct AbsPositiveSimplify
1718517199
if (isa<ComplexType>(operand.getType().getElementType()))
1718617200
return failure();
1718717201

17188-
if (guaranteedNonNegativeResult(operand.getDefiningOp())) {
17202+
if (guaranteedNonNegativeResult(operand.getDefiningOp(), rewriter)) {
1718917203
rewriter.replaceOp(op, op.getOperand());
1719017204
return success();
1719117205
}

0 commit comments

Comments
 (0)