@@ -11394,8 +11394,80 @@ struct WhileSimplify : public OpRewritePattern<stablehlo::WhileOp> {
11394
11394
cond->eraseArgument (i);
11395
11395
11396
11396
deleted++;
11397
- } else if (canHoist && definedOutside (bodyRes, op) && ivInfo.isValid &&
11398
- ivInfo.step != 0 ) {
11397
+ } else {
11398
+ operands.push_back (opOperand.getOperandNumber ());
11399
+ }
11400
+ }
11401
+
11402
+ if (operands.size () == op->getNumOperands ())
11403
+ return failure ();
11404
+
11405
+ SmallVector<Value> newOperands;
11406
+ newOperands.reserve (operands.size ());
11407
+
11408
+ for (auto opOperand : operands) {
11409
+ newOperands.push_back (op->getOperand (opOperand));
11410
+ }
11411
+
11412
+ auto newWhile =
11413
+ rewriter.create <stablehlo::WhileOp>(op.getLoc (), newOperands);
11414
+ newWhile.getCond ().takeBody (op.getCond ());
11415
+ newWhile.getBody ().takeBody (op.getBody ());
11416
+
11417
+ // Replace uses for remaining results.
11418
+ for (const auto &it : llvm::enumerate (operands)) {
11419
+ Value oldRes = op->getResult (it.value ());
11420
+ Value newRes = newWhile->getResult (it.index ());
11421
+
11422
+ rewriter.replaceAllUsesWith (oldRes, newRes);
11423
+ }
11424
+
11425
+ rewriter.eraseOp (op);
11426
+
11427
+ return success ();
11428
+ }
11429
+ };
11430
+
11431
+ // Replace while op iteration variables which are not updated with their
11432
+ // upcoming value
11433
+ struct WhileLICM : public OpRewritePattern <stablehlo::WhileOp> {
11434
+ using OpRewritePattern::OpRewritePattern;
11435
+ bool hoist_all;
11436
+ WhileLICM (bool hoist_all, MLIRContext *context, PatternBenefit benefit = 1 ,
11437
+ ArrayRef<StringRef> generatedNames = {})
11438
+ : OpRewritePattern(context, benefit, generatedNames),
11439
+ hoist_all (hoist_all) {}
11440
+
11441
+ LogicalResult matchAndRewrite (stablehlo::WhileOp op,
11442
+ PatternRewriter &rewriter) const override {
11443
+ SmallVector<unsigned > operands;
11444
+
11445
+ Block *cond = &op.getCond ().front (), *body = &op.getBody ().front ();
11446
+ Operation *bodyTerm = body->getTerminator ();
11447
+
11448
+ int deleted = 0 ;
11449
+
11450
+ // Find the index of IV and the step to check for 1 iteration
11451
+ auto ivInfo = extractSimpleIVInfo (op);
11452
+
11453
+ for (auto &opOperand : op->getOpOperands ()) {
11454
+ Value inputValue = opOperand.get ();
11455
+
11456
+ auto i = opOperand.getOperandNumber () - deleted;
11457
+ Value bodyArg = body->getArgument (i);
11458
+ Value condArg = cond->getArgument (i);
11459
+
11460
+ bool canHoist = inputValue.getDefiningOp <stablehlo::ConstantOp>();
11461
+ if (auto BA = dyn_cast<BlockArgument>(inputValue)) {
11462
+ canHoist |= isa<FunctionOpInterface>(BA.getOwner ()->getParentOp ());
11463
+ } else if (hoist_all) {
11464
+ canHoist = true ;
11465
+ }
11466
+
11467
+ Value bodyRes = bodyTerm->getOperand (i);
11468
+
11469
+ if (canHoist && definedOutside (bodyRes, op) && ivInfo.isValid &&
11470
+ ivInfo.step != 0 ) {
11399
11471
11400
11472
Value resultReplacement;
11401
11473
{
@@ -13070,6 +13142,12 @@ void mlir::transform::addWhileSimplify(RewritePatternSet &patterns,
13070
13142
patterns.insert <WhileSimplify>(hoistAll, &context, benefit);
13071
13143
}
13072
13144
13145
+ void mlir::transform::addWhileLICM (RewritePatternSet &patterns, bool hoistAll,
13146
+ MLIRContext &context,
13147
+ PatternBenefit benefit) {
13148
+ patterns.insert <WhileLICM>(hoistAll, &context, benefit);
13149
+ }
13150
+
13073
13151
void mlir::transform::addSliceLICM (RewritePatternSet &patterns,
13074
13152
bool single_user, MLIRContext &context,
13075
13153
PatternBenefit benefit) {
@@ -13381,6 +13459,8 @@ struct EnzymeHLOOptPass
13381
13459
13382
13460
patterns.add <WhileSimplify>(false , context);
13383
13461
13462
+ patterns.add <WhileLICM>(false , context);
13463
+
13384
13464
// clang-format on
13385
13465
patterns.add <SelectOpCanon>(max_constant_expansion, context,
13386
13466
PatternBenefit (65000 ));
0 commit comments