Skip to content

Commit 9a16aed

Browse files
committed
Improve depth limiting approach
1 parent 5f940da commit 9a16aed

File tree

3 files changed

+28
-20
lines changed

3 files changed

+28
-20
lines changed

mlir/include/mlir/Query/Matcher/ExtraMatchers.h

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,37 +80,33 @@ bool BackwardSliceMatcher<Matcher>::matches(
8080
BackwardSliceOptions &options, int64_t maxDepth) {
8181
backwardSlice.clear();
8282
llvm::DenseMap<Operation *, int64_t> opDepths;
83-
// The starting point is the root op; therefore, we set its depth to 0.
83+
// Initializing the root op with a depth of 0
8484
opDepths[rootOp] = 0;
8585
options.filter = [&](Operation *subOp) {
86-
// If the subOp's depth exceeds maxDepth, we stop further slicing for this
87-
// branch.
88-
if (opDepths[subOp] > maxDepth)
86+
// If the subOp hasn't been recorded in opDepths, it is deeper than
87+
// maxDepth.
88+
if (!opDepths.contains(subOp))
8989
return false;
9090
// Examine subOp's operands to compute depths of their defining operations.
9191
for (auto operand : subOp->getOperands()) {
92+
int64_t newDepth = opDepths[subOp] + 1;
93+
// If the newDepth is greater than maxDepth, further computation can be
94+
// skipped.
95+
if (newDepth > maxDepth)
96+
continue;
97+
9298
if (auto definingOp = operand.getDefiningOp()) {
93-
// Set the defining operation's depth to one level greater than
94-
// subOp's depth.
95-
int64_t newDepth = opDepths[subOp] + 1;
96-
if (!opDepths.contains(definingOp)) {
99+
// Registers the minimum depth
100+
if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
97101
opDepths[definingOp] = newDepth;
98-
} else {
99-
opDepths[definingOp] = std::min(opDepths[definingOp], newDepth);
100-
}
101-
return !(opDepths[subOp] > maxDepth);
102102
} else {
103103
auto blockArgument = cast<BlockArgument>(operand);
104104
Operation *parentOp = blockArgument.getOwner()->getParentOp();
105105
if (!parentOp)
106106
continue;
107-
int64_t newDepth = opDepths[subOp] + 1;
108-
if (!opDepths.contains(parentOp)) {
107+
108+
if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
109109
opDepths[parentOp] = newDepth;
110-
} else {
111-
opDepths[parentOp] = std::min(opDepths[parentOp], newDepth);
112-
}
113-
return !(opDepths[parentOp] > maxDepth);
114110
}
115111
}
116112
return true;
@@ -119,7 +115,7 @@ bool BackwardSliceMatcher<Matcher>::matches(
119115
return true;
120116
}
121117

122-
// Matches transitive defs of a top-level operation up to N levels.
118+
/// Matches transitive defs of a top-level operation up to N levels.
123119
template <typename Matcher>
124120
inline BackwardSliceMatcher<Matcher>
125121
m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
@@ -130,6 +126,15 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130126
omitUsesFromAbove);
131127
}
132128

129+
/// Matches all transitive defs of a top-level operation up to N levels
130+
template <typename Matcher>
131+
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
132+
int64_t maxDepth) {
133+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
134+
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true,
135+
false, false);
136+
}
137+
133138
} // namespace mlir::query::matcher
134139

135140
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H

mlir/test/mlir-query/complex-test.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-query %s -c "m getDefinitions(hasOpName(\"arith.addf\"),2,true,false,false)" | FileCheck %s
1+
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
22

33
#map = affine_map<(d0, d1) -> (d0, d1)>
44
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {

mlir/tools/mlir-query/mlir-query.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ int main(int argc, char **argv) {
4343
matcherRegistry.registerMatcher(
4444
"getDefinitions",
4545
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
46+
matcherRegistry.registerMatcher(
47+
"getAllDefinitions",
48+
query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
4649
matcherRegistry.registerMatcher("hasOpAttrName",
4750
static_cast<HasOpAttrName *>(m_Attr));
4851
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));

0 commit comments

Comments
 (0)