Skip to content

Commit

Permalink
[PlutoTransform] canonicalize and segfault fix
Browse files Browse the repository at this point in the history
[PlutoTransform] canonicalize after dedup cast

[PlutoTransform] fix dedup segfault
  • Loading branch information
kumasento committed Nov 15, 2021
1 parent e48c8d4 commit 77dec25
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
8 changes: 4 additions & 4 deletions lib/Target/OpenScop/ConvertToOpenScop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,6 @@ void OslScopBuilder::buildScopContext(OslScop *scop,
FlatAffineValueConstraints *domain = it.second.getDomain();
FlatAffineValueConstraints cst(*domain);

ctx.mergeAndAlignIdsWithOther(0, &cst);
ctx.append(cst);
ctx.removeRedundantConstraints();

LLVM_DEBUG(dbgs() << "Statement:\n");
LLVM_DEBUG(it.second.getCaller().dump());
LLVM_DEBUG(it.second.getCallee().dump());
Expand All @@ -251,6 +247,10 @@ void OslScopBuilder::buildScopContext(OslScop *scop,
dbgs() << " * " << value << '\n';
});

ctx.mergeAndAlignIdsWithOther(0, &cst);
ctx.append(cst);
ctx.removeRedundantConstraints();

LLVM_DEBUG(dbgs() << "Updated context: \n");
LLVM_DEBUG(ctx.dump());

Expand Down
58 changes: 35 additions & 23 deletions lib/Transforms/PlutoTransform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,28 +132,6 @@ static mlir::FuncOp plutoTransform(mlir::FuncOp f, OpBuilder &rewriter,
return g;
}

static void dedupIndexCast(FuncOp f) {
Block &entry = f.getBlocks().front();
llvm::MapVector<Value, Value> argToCast;
SmallVector<Operation *> toErase;
for (auto &op : entry) {
if (auto indexCast = dyn_cast<arith::IndexCastOp>(&op)) {
auto arg = indexCast.getOperand().dyn_cast<BlockArgument>();
if (argToCast.count(arg)) {
LLVM_DEBUG(dbgs() << "Found duplicated index_cast: " << indexCast
<< '\n');
indexCast.replaceAllUsesWith(argToCast.lookup(arg));
toErase.push_back(indexCast);
} else {
argToCast[arg] = indexCast;
}
}
}

for (auto op : toErase)
op->erase();
}

namespace {
class PlutoTransformPass
: public mlir::PassWrapper<PlutoTransformPass,
Expand Down Expand Up @@ -183,7 +161,6 @@ class PlutoTransformPass

m.walk([&](mlir::FuncOp f) {
if (!f->getAttr("scop.stmt") && !f->hasAttr("scop.ignored")) {
dedupIndexCast(f);
funcOps.push_back(f);
}
});
Expand Down Expand Up @@ -300,10 +277,45 @@ struct PlutoParallelizePass
};
} // namespace

static void dedupIndexCast(FuncOp f) {
if (f.getBlocks().empty())
return;

Block &entry = f.getBlocks().front();
llvm::MapVector<Value, Value> argToCast;
SmallVector<Operation *> toErase;
for (auto &op : entry) {
if (auto indexCast = dyn_cast<arith::IndexCastOp>(&op)) {
auto arg = indexCast.getOperand().dyn_cast<BlockArgument>();
if (argToCast.count(arg)) {
LLVM_DEBUG(dbgs() << "Found duplicated index_cast: " << indexCast
<< '\n');
indexCast.replaceAllUsesWith(argToCast.lookup(arg));
toErase.push_back(indexCast);
} else {
argToCast[arg] = indexCast;
}
}
}

for (auto op : toErase)
op->erase();
}

namespace {
struct DedupIndexCastPass
: public mlir::PassWrapper<DedupIndexCastPass,
OperationPass<mlir::FuncOp>> {
void runOnOperation() override { dedupIndexCast(getOperation()); }
};
} // namespace

void polymer::registerPlutoTransformPass() {
PassPipelineRegistration<PlutoOptPipelineOptions>(
"pluto-opt", "Optimization implemented by PLUTO.",
[](OpPassManager &pm, const PlutoOptPipelineOptions &pipelineOptions) {
pm.addPass(std::make_unique<DedupIndexCastPass>());
pm.addPass(createCanonicalizerPass());
pm.addPass(std::make_unique<PlutoTransformPass>(pipelineOptions));
pm.addPass(createCanonicalizerPass());
if (pipelineOptions.generateParallel) {
Expand Down

0 comments on commit 77dec25

Please sign in to comment.