Skip to content

Commit 68cb5b0

Browse files
authored
Separate out while licm (EnzymeAD#648)
* Separate out while licm * fmt * more infra * fix
1 parent ab3da6e commit 68cb5b0

File tree

5 files changed

+103
-2
lines changed

5 files changed

+103
-2
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11394,8 +11394,80 @@ struct WhileSimplify : public OpRewritePattern<stablehlo::WhileOp> {
1139411394
cond->eraseArgument(i);
1139511395

1139611396
deleted++;
11397-
} else if (canHoist && definedOutside(bodyRes, op) && ivInfo.isValid &&
11398-
ivInfo.step != 0) {
11397+
} else {
11398+
operands.push_back(opOperand.getOperandNumber());
11399+
}
11400+
}
11401+
11402+
if (operands.size() == op->getNumOperands())
11403+
return failure();
11404+
11405+
SmallVector<Value> newOperands;
11406+
newOperands.reserve(operands.size());
11407+
11408+
for (auto opOperand : operands) {
11409+
newOperands.push_back(op->getOperand(opOperand));
11410+
}
11411+
11412+
auto newWhile =
11413+
rewriter.create<stablehlo::WhileOp>(op.getLoc(), newOperands);
11414+
newWhile.getCond().takeBody(op.getCond());
11415+
newWhile.getBody().takeBody(op.getBody());
11416+
11417+
// Replace uses for remaining results.
11418+
for (const auto &it : llvm::enumerate(operands)) {
11419+
Value oldRes = op->getResult(it.value());
11420+
Value newRes = newWhile->getResult(it.index());
11421+
11422+
rewriter.replaceAllUsesWith(oldRes, newRes);
11423+
}
11424+
11425+
rewriter.eraseOp(op);
11426+
11427+
return success();
11428+
}
11429+
};
11430+
11431+
// Replace while op iteration variables which are not updated with their
11432+
// upcoming value
11433+
struct WhileLICM : public OpRewritePattern<stablehlo::WhileOp> {
11434+
using OpRewritePattern::OpRewritePattern;
11435+
bool hoist_all;
11436+
WhileLICM(bool hoist_all, MLIRContext *context, PatternBenefit benefit = 1,
11437+
ArrayRef<StringRef> generatedNames = {})
11438+
: OpRewritePattern(context, benefit, generatedNames),
11439+
hoist_all(hoist_all) {}
11440+
11441+
LogicalResult matchAndRewrite(stablehlo::WhileOp op,
11442+
PatternRewriter &rewriter) const override {
11443+
SmallVector<unsigned> operands;
11444+
11445+
Block *cond = &op.getCond().front(), *body = &op.getBody().front();
11446+
Operation *bodyTerm = body->getTerminator();
11447+
11448+
int deleted = 0;
11449+
11450+
// Find the index of IV and the step to check for 1 iteration
11451+
auto ivInfo = extractSimpleIVInfo(op);
11452+
11453+
for (auto &opOperand : op->getOpOperands()) {
11454+
Value inputValue = opOperand.get();
11455+
11456+
auto i = opOperand.getOperandNumber() - deleted;
11457+
Value bodyArg = body->getArgument(i);
11458+
Value condArg = cond->getArgument(i);
11459+
11460+
bool canHoist = inputValue.getDefiningOp<stablehlo::ConstantOp>();
11461+
if (auto BA = dyn_cast<BlockArgument>(inputValue)) {
11462+
canHoist |= isa<FunctionOpInterface>(BA.getOwner()->getParentOp());
11463+
} else if (hoist_all) {
11464+
canHoist = true;
11465+
}
11466+
11467+
Value bodyRes = bodyTerm->getOperand(i);
11468+
11469+
if (canHoist && definedOutside(bodyRes, op) && ivInfo.isValid &&
11470+
ivInfo.step != 0) {
1139911471

1140011472
Value resultReplacement;
1140111473
{
@@ -13070,6 +13142,12 @@ void mlir::transform::addWhileSimplify(RewritePatternSet &patterns,
1307013142
patterns.insert<WhileSimplify>(hoistAll, &context, benefit);
1307113143
}
1307213144

13145+
void mlir::transform::addWhileLICM(RewritePatternSet &patterns, bool hoistAll,
13146+
MLIRContext &context,
13147+
PatternBenefit benefit) {
13148+
patterns.insert<WhileLICM>(hoistAll, &context, benefit);
13149+
}
13150+
1307313151
void mlir::transform::addSliceLICM(RewritePatternSet &patterns,
1307413152
bool single_user, MLIRContext &context,
1307513153
PatternBenefit benefit) {
@@ -13381,6 +13459,8 @@ struct EnzymeHLOOptPass
1338113459

1338213460
patterns.add<WhileSimplify>(false, context);
1338313461

13462+
patterns.add<WhileLICM>(false, context);
13463+
1338413464
// clang-format on
1338513465
patterns.add<SelectOpCanon>(max_constant_expansion, context,
1338613466
PatternBenefit(65000));

src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ void addIotaSimplify(RewritePatternSet &patterns, int64_t maxConstantExpansion,
2626
MLIRContext &context, PatternBenefit benefit);
2727
void addWhileSimplify(RewritePatternSet &patterns, bool hoist_all,
2828
MLIRContext &context, PatternBenefit benefit);
29+
void addWhileLICM(RewritePatternSet &patterns, bool hoist_all,
30+
MLIRContext &context, PatternBenefit benefit);
2931
void addSliceLICM(RewritePatternSet &patterns, bool single_user,
3032
MLIRContext &context, PatternBenefit benefit);
3133
void addDUSLICM(RewritePatternSet &patterns, bool single_user,

src/enzyme_ad/jax/TransformOps/TransformOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ void ApplyWhileSimplifyPatterns::populatePatterns(RewritePatternSet &patterns) {
3636
addWhileSimplify(patterns, getParameter(), *getContext(),
3737
PatternBenefit(getBenefit().value_or(0)));
3838
}
39+
void ApplyWhileLICMPatterns::populatePatterns(RewritePatternSet &patterns) {
40+
addWhileLICM(patterns, getParameter(), *getContext(),
41+
PatternBenefit(getBenefit().value_or(0)));
42+
}
3943
void ApplySliceLICMPatterns::populatePatterns(RewritePatternSet &patterns) {
4044
addSliceLICM(patterns, getParameter(), *getContext(),
4145
PatternBenefit(getBenefit().value_or(1)));

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,21 @@ def ApplyWhileSimplifyPatterns : EnzymeHLOParameterizedPatternOp<
795795
}];
796796
}
797797

798+
def ApplyWhileLICMPatterns : EnzymeHLOParameterizedPatternOp<
799+
"while_licm"> {
800+
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
801+
let assemblyFormat = "attr-dict";
802+
// TODO: this should be made better searchable.
803+
let extraClassDeclaration = [{
804+
::llvm::SmallVector<::mlir::DictionaryAttr>
805+
static getPossibleAttrCombinations(::mlir::Builder &builder) {
806+
return {builder.getDictionaryAttr(
807+
builder.getNamedAttr("parameter",
808+
builder.getBoolAttr(false)))};
809+
}
810+
}];
811+
}
812+
798813
def ApplySliceLICMPatterns : EnzymeHLOParameterizedPatternOp<
799814
"slice_licm"> {
800815
let arguments = (ins OptionalAttr<I64Attr>:$benefit, BoolAttr:$parameter);
File renamed without changes.

0 commit comments

Comments
 (0)