Skip to content

Commit 2aa0578

Browse files
committed
Revert "Correct while induction (EnzymeAD#669)"
This reverts commit fa5353d.
1 parent fa5353d commit 2aa0578

File tree

1 file changed

+39
-4
lines changed

1 file changed

+39
-4
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11163,13 +11163,31 @@ struct WhileInductionReduction : public OpRewritePattern<stablehlo::WhileOp> {
1116311163
for (int i = 0; i < pair.lowerBounds.size(); i++) {
1116411164
update_starts.push_back(rewriter.create<stablehlo::ConstantOp>(
1116511165
pair.argOperand.getLoc(), itype,
11166-
makeAttr(itype, pair.lowerUpdateBounds[i])
11166+
makeAttr(itype,
11167+
pair.lowerUpdateBounds[i] - pair.lowerBounds[i])
1116711168
.cast<ElementsAttr>()));
1116811169
}
1116911170

1117011171
newArg = rewriter.create<stablehlo::DynamicUpdateSliceOp>(
11171-
pair.argOperand.getLoc(), pair.outerOperand, newArg,
11172+
pair.argOperand.getLoc(), pair.outerSlice, newArg,
1117211173
update_starts);
11174+
11175+
auto ctype = RankedTensorType::get(
11176+
{}, cast<RankedTensorType>(pair.argOperand.getType())
11177+
.getElementType());
11178+
auto padVal = rewriter.create<stablehlo::ConstantOp>(
11179+
pair.argOperand.getLoc(), ctype,
11180+
makeAttr(ctype, 0).cast<ElementsAttr>());
11181+
11182+
SmallVector<int64_t> slow = llvm::to_vector(pair.lowerBounds);
11183+
SmallVector<int64_t> shigh = llvm::to_vector(
11184+
cast<RankedTensorType>(pair.argOperand.getType()).getShape());
11185+
for (int i = 0; i < shigh.size(); i++)
11186+
shigh[i] -= pair.upperBounds[i];
11187+
SmallVector<int64_t> sint(shigh.size(), 0);
11188+
11189+
newArg = rewriter.create<stablehlo::PadOp>(
11190+
pair.argOperand.getLoc(), newArg, padVal, slow, shigh, sint);
1117311191
break;
1117411192
}
1117511193
}
@@ -11229,13 +11247,30 @@ struct WhileInductionReduction : public OpRewritePattern<stablehlo::WhileOp> {
1122911247
for (int i = 0; i < pair.lowerBounds.size(); i++) {
1123011248
update_starts.push_back(rewriter.create<stablehlo::ConstantOp>(
1123111249
pair.argOperand.getLoc(), itype,
11232-
makeAttr(itype, pair.lowerUpdateBounds[i])
11250+
makeAttr(itype,
11251+
pair.lowerUpdateBounds[i] - pair.lowerBounds[i])
1123311252
.cast<ElementsAttr>()));
1123411253
}
1123511254

1123611255
newArg = rewriter.create<stablehlo::DynamicUpdateSliceOp>(
11237-
pair.argOperand.getLoc(), pair.outerOperand, newArg,
11256+
pair.argOperand.getLoc(), pair.outerSlice, newArg,
1123811257
update_starts);
11258+
auto ctype = RankedTensorType::get(
11259+
{}, cast<RankedTensorType>(pair.condOperand.getType())
11260+
.getElementType());
11261+
auto padVal = rewriter.create<stablehlo::ConstantOp>(
11262+
pair.condOperand.getLoc(), ctype,
11263+
makeAttr(ctype, 0).cast<ElementsAttr>());
11264+
11265+
SmallVector<int64_t> slow = llvm::to_vector(pair.lowerBounds);
11266+
SmallVector<int64_t> shigh = llvm::to_vector(
11267+
cast<RankedTensorType>(pair.condOperand.getType()).getShape());
11268+
for (int i = 0; i < shigh.size(); i++)
11269+
shigh[i] -= pair.upperBounds[i];
11270+
SmallVector<int64_t> sint(shigh.size(), 0);
11271+
11272+
newArg = rewriter.create<stablehlo::PadOp>(
11273+
pair.condOperand.getLoc(), newArg, padVal, slow, shigh, sint);
1123911274
break;
1124011275
}
1124111276
}

0 commit comments

Comments
 (0)