Skip to content

Commit f62994d

Browse files
committed
[CIR] Skip generation of a continue block when flattening TernaryOp
We used to insert a continue Block at the end of a flattened ternary op that only contained a branch to the remaing operation of the remaining Block. This patch removes that continue block and changes the true/false blocks to directly jump to the remaining ops. With this patch the CIR now generates exactly the same LLVM IR as the original codegen. This upstreams llvm/clangir#1651.
1 parent 597340b commit f62994d

File tree

3 files changed

+15
-22
lines changed

3 files changed

+15
-22
lines changed

clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
#include "mlir/IR/Block.h"
1717
#include "mlir/IR/Builders.h"
1818
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/IR/ValueRange.h"
1920
#include "mlir/Support/LogicalResult.h"
2021
#include "mlir/Transforms/DialectConversion.h"
2122
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
#include "clang/AST/DeclBase.h"
2224
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2325
#include "clang/CIR/Dialect/Passes.h"
2426
#include "clang/CIR/MissingFeatures.h"
@@ -491,15 +493,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
491493
Location loc = op->getLoc();
492494
Block *condBlock = rewriter.getInsertionBlock();
493495
Block::iterator opPosition = rewriter.getInsertionPoint();
494-
Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
495-
llvm::SmallVector<mlir::Location, 2> locs;
496-
// Ternary result is optional, make sure to populate the location only
497-
// when relevant.
498-
if (op->getResultTypes().size())
499-
locs.push_back(loc);
500-
Block *continueBlock =
501-
rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
502-
rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
496+
auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
503497

504498
Region &trueRegion = op.getTrueRegion();
505499
Block *trueBlock = &trueRegion.front();
@@ -508,24 +502,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
508502
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
509503

510504
rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
511-
continueBlock);
512-
rewriter.inlineRegionBefore(trueRegion, continueBlock);
505+
remainingOpsBlock);
506+
rewriter.inlineRegionBefore(trueRegion, remainingOpsBlock);
513507

514-
Block *falseBlock = continueBlock;
515508
Region &falseRegion = op.getFalseRegion();
509+
Block *falseBlock = &falseRegion.front();
516510

517-
falseBlock = &falseRegion.front();
518511
mlir::Operation *falseTerminator = falseRegion.back().getTerminator();
519512
rewriter.setInsertionPointToEnd(&falseRegion.back());
520513
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
521514
rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),
522-
continueBlock);
523-
rewriter.inlineRegionBefore(falseRegion, continueBlock);
515+
remainingOpsBlock);
516+
rewriter.inlineRegionBefore(falseRegion, remainingOpsBlock);
524517

525518
rewriter.setInsertionPointToEnd(condBlock);
526519
rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
527520

528-
rewriter.replaceOp(op, continueBlock->getArguments());
521+
if (auto rt = op.getResultTypes(); rt.size()) {
522+
iterator_range args = remainingOpsBlock->addArguments(rt, op.getLoc());
523+
SmallVector<mlir::Value, 2> values;
524+
llvm::copy(args, std::back_inserter(values));
525+
rewriter.replaceOpUsesWithinBlock(op, values, remainingOpsBlock);
526+
}
527+
rewriter.eraseOp(op);
529528

530529
// Ok, we're done!
531530
return mlir::success();

clang/test/CIR/Lowering/ternary.cir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,4 @@ module {
2525
// LLVM: br label %[[M]]
2626
// LLVM: [[M]]:
2727
// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
28-
// LLVM: br label %[[B3:[[:alnum:]]+]]
29-
// LLVM: [[B3]]:
3028
// LLVM: ret i32 [[R]]

clang/test/CIR/Transforms/ternary.cir

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ module {
3737
// CHECK: %6 = cir.const #cir.int<5> : !s32i
3838
// CHECK: cir.br ^bb3(%6 : !s32i)
3939
// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2
40-
// CHECK: cir.br ^bb4
41-
// CHECK: ^bb4: // pred: ^bb3
4240
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
4341
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
4442
// CHECK: cir.return %8 : !s32i
@@ -60,8 +58,6 @@ module {
6058
// CHECK: ^bb2: // pred: ^bb0
6159
// CHECK: cir.br ^bb3
6260
// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2
63-
// CHECK: cir.br ^bb4
64-
// CHECK: ^bb4: // pred: ^bb3
6561
// CHECK: cir.return
6662
// CHECK: }
6763

0 commit comments

Comments
 (0)