16
16
#include " mlir/IR/Block.h"
17
17
#include " mlir/IR/Builders.h"
18
18
#include " mlir/IR/PatternMatch.h"
19
+ #include " mlir/IR/ValueRange.h"
19
20
#include " mlir/Support/LogicalResult.h"
20
21
#include " mlir/Transforms/DialectConversion.h"
21
22
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
23
+ #include " clang/AST/DeclBase.h"
22
24
#include " clang/CIR/Dialect/IR/CIRDialect.h"
23
25
#include " clang/CIR/Dialect/Passes.h"
24
26
#include " clang/CIR/MissingFeatures.h"
@@ -491,15 +493,7 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
491
493
Location loc = op->getLoc ();
492
494
Block *condBlock = rewriter.getInsertionBlock ();
493
495
Block::iterator opPosition = rewriter.getInsertionPoint ();
494
- Block *remainingOpsBlock = rewriter.splitBlock (condBlock, opPosition);
495
- llvm::SmallVector<mlir::Location, 2 > locs;
496
- // Ternary result is optional, make sure to populate the location only
497
- // when relevant.
498
- if (op->getResultTypes ().size ())
499
- locs.push_back (loc);
500
- Block *continueBlock =
501
- rewriter.createBlock (remainingOpsBlock, op->getResultTypes (), locs);
502
- rewriter.create <cir::BrOp>(loc, remainingOpsBlock);
496
+ auto *remainingOpsBlock = rewriter.splitBlock (condBlock, opPosition);
503
497
504
498
Region &trueRegion = op.getTrueRegion ();
505
499
Block *trueBlock = &trueRegion.front ();
@@ -508,24 +502,29 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
508
502
auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);
509
503
510
504
rewriter.replaceOpWithNewOp <cir::BrOp>(trueYieldOp, trueYieldOp.getArgs (),
511
- continueBlock );
512
- rewriter.inlineRegionBefore (trueRegion, continueBlock );
505
+ remainingOpsBlock );
506
+ rewriter.inlineRegionBefore (trueRegion, remainingOpsBlock );
513
507
514
- Block *falseBlock = continueBlock;
515
508
Region &falseRegion = op.getFalseRegion ();
509
+ Block *falseBlock = &falseRegion.front ();
516
510
517
- falseBlock = &falseRegion.front ();
518
511
mlir::Operation *falseTerminator = falseRegion.back ().getTerminator ();
519
512
rewriter.setInsertionPointToEnd (&falseRegion.back ());
520
513
auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);
521
514
rewriter.replaceOpWithNewOp <cir::BrOp>(falseYieldOp, falseYieldOp.getArgs (),
522
- continueBlock );
523
- rewriter.inlineRegionBefore (falseRegion, continueBlock );
515
+ remainingOpsBlock );
516
+ rewriter.inlineRegionBefore (falseRegion, remainingOpsBlock );
524
517
525
518
rewriter.setInsertionPointToEnd (condBlock);
526
519
rewriter.create <cir::BrCondOp>(loc, op.getCond (), trueBlock, falseBlock);
527
520
528
- rewriter.replaceOp (op, continueBlock->getArguments ());
521
+ if (auto rt = op.getResultTypes (); rt.size ()) {
522
+ iterator_range args = remainingOpsBlock->addArguments (rt, op.getLoc ());
523
+ SmallVector<mlir::Value, 2 > values;
524
+ llvm::copy (args, std::back_inserter (values));
525
+ rewriter.replaceOpUsesWithinBlock (op, values, remainingOpsBlock);
526
+ }
527
+ rewriter.eraseOp (op);
529
528
530
529
// Ok, we're done!
531
530
return mlir::success ();
0 commit comments