@@ -80,37 +80,33 @@ bool BackwardSliceMatcher<Matcher>::matches(
80
80
BackwardSliceOptions &options, int64_t maxDepth) {
81
81
backwardSlice.clear ();
82
82
llvm::DenseMap<Operation *, int64_t > opDepths;
83
- // The starting point is the root op; therefore, we set its depth to 0.
83
+ // Initializing the root op with a depth of 0
84
84
opDepths[rootOp] = 0 ;
85
85
options.filter = [&](Operation *subOp) {
86
- // If the subOp's depth exceeds maxDepth, we stop further slicing for this
87
- // branch .
88
- if (opDepths[ subOp] > maxDepth )
86
+ // If the subOp hasn't been recorded in opDepths, it is deeper than
87
+ // maxDepth .
88
+ if (! opDepths. contains ( subOp) )
89
89
return false ;
90
90
// Examine subOp's operands to compute depths of their defining operations.
91
91
for (auto operand : subOp->getOperands ()) {
92
+ int64_t newDepth = opDepths[subOp] + 1 ;
93
+ // If the newDepth is greater than maxDepth, further computation can be
94
+ // skipped.
95
+ if (newDepth > maxDepth)
96
+ continue ;
97
+
92
98
if (auto definingOp = operand.getDefiningOp ()) {
93
- // Set the defining operation's depth to one level greater than
94
- // subOp's depth.
95
- int64_t newDepth = opDepths[subOp] + 1 ;
96
- if (!opDepths.contains (definingOp)) {
99
+ // Registers the minimum depth
100
+ if (!opDepths.contains (definingOp) || newDepth < opDepths[definingOp])
97
101
opDepths[definingOp] = newDepth;
98
- } else {
99
- opDepths[definingOp] = std::min (opDepths[definingOp], newDepth);
100
- }
101
- return !(opDepths[subOp] > maxDepth);
102
102
} else {
103
103
auto blockArgument = cast<BlockArgument>(operand);
104
104
Operation *parentOp = blockArgument.getOwner ()->getParentOp ();
105
105
if (!parentOp)
106
106
continue ;
107
- int64_t newDepth = opDepths[subOp] + 1 ;
108
- if (!opDepths.contains (parentOp)) {
107
+
108
+ if (!opDepths.contains (parentOp) || newDepth < opDepths[parentOp])
109
109
opDepths[parentOp] = newDepth;
110
- } else {
111
- opDepths[parentOp] = std::min (opDepths[parentOp], newDepth);
112
- }
113
- return !(opDepths[parentOp] > maxDepth);
114
110
}
115
111
}
116
112
return true ;
@@ -119,7 +115,7 @@ bool BackwardSliceMatcher<Matcher>::matches(
119
115
return true ;
120
116
}
121
117
122
- // Matches transitive defs of a top-level operation up to N levels.
118
+ // / Matches transitive defs of a top-level operation up to N levels.
123
119
template <typename Matcher>
124
120
inline BackwardSliceMatcher<Matcher>
125
121
m_GetDefinitions (Matcher innerMatcher, int64_t maxDepth, bool inclusive,
@@ -130,6 +126,15 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
130
126
omitUsesFromAbove);
131
127
}
132
128
129
+ // / Matches all transitive defs of a top-level operation up to N levels
130
+ template <typename Matcher>
131
+ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions (Matcher innerMatcher,
132
+ int64_t maxDepth) {
133
+ assert (maxDepth >= 0 && " maxDepth must be non-negative" );
134
+ return BackwardSliceMatcher<Matcher>(std::move (innerMatcher), maxDepth, true ,
135
+ false , false );
136
+ }
137
+
133
138
} // namespace mlir::query::matcher
134
139
135
140
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_EXTRAMATCHERS_H
0 commit comments