Skip to content

[mlir][SliceAnalysis] Fix stack overflow in graph regions #139694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mlir/include/mlir/Analysis/SliceAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ using ForwardSliceOptions = SliceOptions;
///
/// The implementation traverses the use chains in postorder traversal for
/// efficiency reasons: if an operation is already in `forwardSlice`, no
/// need to traverse its uses again. Since use-def chains form a DAG, this
/// terminates.
/// need to traverse its uses again. In the presence of use-def cycles in a
/// graph region, the traversal stops at the first operation that was already
/// visited (which is not added to the slice anymore).
///
/// Upon return to the root call, `forwardSlice` is filled with a
/// postorder list of uses (i.e. a reverse topological order). To get a proper
Expand Down Expand Up @@ -114,8 +115,9 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
///
/// The implementation traverses the def chains in postorder traversal for
/// efficiency reasons: if an operation is already in `backwardSlice`, no
/// need to traverse its definitions again. Since useuse-def chains form a DAG,
/// this terminates.
/// need to traverse its definitions again. In the presence of use-def cycles
/// in a graph region, the traversal stops at the first operation that was
/// already visited (which is not added to the slice anymore).
///
/// Upon return to the root call, `backwardSlice` is filled with a
/// postorder list of defs. This happens to be a topological order, from the
Expand Down
63 changes: 47 additions & 16 deletions mlir/lib/Analysis/SliceAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
using namespace mlir;

static void
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
getForwardSliceImpl(Operation *op, DenseSet<Operation *> &visited,
SetVector<Operation *> *forwardSlice,
const SliceOptions::TransitiveFilter &filter = nullptr) {
if (!op)
return;
Expand All @@ -41,20 +42,41 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
for (Region &region : op->getRegions())
for (Block &block : region)
for (Operation &blockOp : block)
if (forwardSlice->count(&blockOp) == 0)
getForwardSliceImpl(&blockOp, forwardSlice, filter);
for (Value result : op->getResults()) {
for (Operation *userOp : result.getUsers())
if (forwardSlice->count(userOp) == 0)
getForwardSliceImpl(userOp, forwardSlice, filter);
}
if (forwardSlice->count(&blockOp) == 0) {
// We don't have to check if the 'blockOp' is already visited because
// there cannot be a traversal path from this nested op to the parent
// and thus a cycle cannot be closed here. We still have to mark it
// as visited to stop before visiting this operation again if it is
// part of a cycle.
visited.insert(&blockOp);
getForwardSliceImpl(&blockOp, visited, forwardSlice, filter);
visited.erase(&blockOp);
}

for (Value result : op->getResults())
for (Operation *userOp : result.getUsers()) {
// A cycle can only occur within a basic block (not across regions or
// basic blocks) because the parent region must be a graph region, graph
// regions are restricted to always have 0 or 1 blocks, and there cannot
// be a def-use edge from a nested operation to an operation in an
// ancestor region. Therefore, we don't have to but may use the same
// 'visited' set across regions/blocks as long as we remove operations
// from the set again when the DFS traverses back from the leaf to the
// root.
if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second)
getForwardSliceImpl(userOp, visited, forwardSlice, filter);

visited.erase(userOp);
}

forwardSlice->insert(op);
}

void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
const ForwardSliceOptions &options) {
getForwardSliceImpl(op, forwardSlice, options.filter);
DenseSet<Operation *> visited;
visited.insert(op);
getForwardSliceImpl(op, visited, forwardSlice, options.filter);
if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
// want it in the results.
Expand All @@ -70,8 +92,12 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,

void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
const SliceOptions &options) {
for (Operation *user : root.getUsers())
getForwardSliceImpl(user, forwardSlice, options.filter);
DenseSet<Operation *> visited;
for (Operation *user : root.getUsers()) {
visited.insert(user);
getForwardSliceImpl(user, visited, forwardSlice, options.filter);
visited.erase(user);
}

// Reverse to get back the actual topological order.
// std::reverse does not work out of the box on SetVector and I want an
Expand All @@ -80,7 +106,7 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
forwardSlice->insert(v.rbegin(), v.rend());
}

static void getBackwardSliceImpl(Operation *op,
static void getBackwardSliceImpl(Operation *op, DenseSet<Operation *> &visited,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
Expand All @@ -94,8 +120,11 @@ static void getBackwardSliceImpl(Operation *op,

auto processValue = [&](Value value) {
if (auto *definingOp = value.getDefiningOp()) {
if (backwardSlice->count(definingOp) == 0)
getBackwardSliceImpl(definingOp, backwardSlice, options);
if (backwardSlice->count(definingOp) == 0 &&
visited.insert(definingOp).second)
getBackwardSliceImpl(definingOp, visited, backwardSlice, options);

visited.erase(definingOp);
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
if (options.omitBlockArguments)
return;
Expand All @@ -108,7 +137,7 @@ static void getBackwardSliceImpl(Operation *op,
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
getBackwardSliceImpl(parentOp, backwardSlice, options);
getBackwardSliceImpl(parentOp, visited, backwardSlice, options);
}
} else {
llvm_unreachable("No definingOp and not a block argument.");
Expand Down Expand Up @@ -138,7 +167,9 @@ static void getBackwardSliceImpl(Operation *op,
void mlir::getBackwardSlice(Operation *op,
SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options) {
getBackwardSliceImpl(op, backwardSlice, options);
DenseSet<Operation *> visited;
visited.insert(op);
getBackwardSliceImpl(op, visited, backwardSlice, options);

if (!options.inclusive) {
// Don't insert the top level operation, we just queried on it and don't
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Affine/slicing-utils.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,26 @@ func.func @slicing_test_multiple_return(%arg0: index) -> (index, index) {
%0:2 = "slicing-test-op"(%arg0, %arg0): (index, index) -> (index, index)
return %0#0, %0#1 : index, index
}

// -----

// FWD-LABEL: graph_region_with_cycle
// BWD-LABEL: graph_region_with_cycle
// FWDBWD-LABEL: graph_region_with_cycle
func.func @graph_region_with_cycle() {
test.isolated_graph_region {
// FWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 forward static slice:
// FWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1
// FWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 forward static slice:
// FWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1

// BWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 backward static slice:
// BWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1
// BWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 backward static slice:
// BWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1
%0 = "slicing-test-op"(%1) : (i1) -> i1
%1 = "slicing-test-op"(%0) : (i1) -> i1
}

return
}
Loading