Skip to content

Commit 9c383f6

Browse files
jroeschzhiics
authored andcommitted
[PassManager] Implement pass manager tracing API (#4782)
* Implement pass tracing API * Set is_before correctly * Add docs for trace function * Fix lint * Remove PDB * Ensure trace_func is set before calling * Fix conditional
1 parent d54036a commit 9c383f6

File tree

6 files changed

+94
-5
lines changed

6 files changed

+94
-5
lines changed

docs/dev/relay_pass_infra.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
621621
dump out the module IR when ``FoldConstant`` is done. Users can plug in this
622622
pass after any pass they want to debug for viewing the optimization effect.
623623

624+
There is a more flexible debugging mechanism also exposed by the build configuration
625+
object. One can pass a tracing function which can be used to execute arbitrary code
626+
before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``,
627+
and a boolean indicating whether you are executing before, or after a pass.
628+
An example is below.
629+
630+
.. code:: python
631+
632+
def print_ir(mod, info, is_before):
633+
"""Print the name of the pass, the IR, only before passes execute."""
634+
if is_before:
635+
print(f"Running pass: {}", info)
636+
print(mod)
637+
638+
with relay.build_config(opt_level=3, trace=print_ir):
639+
with tvm.target.create("llvm"):
640+
# Perform the optimizations.
641+
mod = seq(mod)
642+
643+
624644
For more pass infra related examples in Python and C++, please refer to
625645
`tests/python/relay/test_pass_manager.py`_ and
626646
`tests/cpp/relay_transform_sequential.cc`_, respectively.

include/tvm/ir/transform.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@
6565
namespace tvm {
6666
namespace transform {
6767

68+
// Forward declare for TraceFunc.
69+
class PassInfo;
70+
71+
/*! \brief A callback for tracing passes, useful for debugging and logging.
72+
*
73+
*/
74+
using TraceFunc =
75+
runtime::TypedPackedFunc<void(const IRModule& ir_module,
76+
const PassInfo& ctx,
77+
bool is_before)>;
78+
6879
/*!
6980
* \brief PassContextNode contains the information that a pass can rely on,
7081
* such as analysis results.
@@ -88,6 +99,8 @@ class PassContextNode : public Object {
8899
/*! \brief The list of disabled passes. */
89100
Array<PrimExpr> disabled_pass;
90101

102+
TraceFunc trace_func;
103+
91104
PassContextNode() = default;
92105

93106
void VisitAttrs(AttrVisitor* v) {
@@ -101,6 +114,7 @@ class PassContextNode : public Object {
101114
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
102115
};
103116

117+
104118
/*!
105119
* \brief PassContext that is used to configure the pass behavior.
106120
*
@@ -146,6 +160,14 @@ class PassContext : public ObjectRef {
146160
*/
147161
TVM_DLL static PassContext Current();
148162

163+
/*!
164+
* \brief Apply the tracing functions of the context to the module, with the info.
165+
* \param module The IRModule to trace.
166+
* \param info The pass information.
167+
* \param is_before Indicated whether the tracing is before or after a pass.
168+
*/
169+
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;
170+
149171
// accessor.
150172
using ContainerType = PassContextNode;
151173
class Internal;

python/tvm/relay/transform.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def __init__(self,
7878
opt_level=2,
7979
fallback_device=_nd.cpu(),
8080
required_pass=None,
81-
disabled_pass=None):
81+
disabled_pass=None,
82+
trace=None):
8283
if isinstance(fallback_device, str):
8384
fallback_device = _nd.context(fallback_device).device_type
8485
elif isinstance(fallback_device, TVMContext):
@@ -99,7 +100,7 @@ def __init__(self,
99100

100101
self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
101102
fallback_device, required,
102-
disabled)
103+
disabled, trace)
103104

104105
def __enter__(self):
105106
_transform.EnterPassContext(self)
@@ -117,7 +118,8 @@ def current():
117118
def build_config(opt_level=2,
118119
fallback_device=_nd.cpu(),
119120
required_pass=None,
120-
disabled_pass=None):
121+
disabled_pass=None,
122+
trace=None):
121123
"""Configure the build behavior by setting config variables.
122124
123125
Parameters
@@ -151,13 +153,16 @@ def build_config(opt_level=2,
151153
disabled_pass: set of str, optional
152154
Optimization passes to be disabled during optimization.
153155
156+
trace: Callable[[IRModule, PassInfo, bool], None]
157+
A tracing function for debugging or introspection.
158+
154159
Returns
155160
-------
156161
pass_context: PassContext
157162
The pass context for optimizations.
158163
"""
159164
return PassContext(opt_level, fallback_device, required_pass,
160-
disabled_pass)
165+
disabled_pass, trace)
161166

162167

163168
@register_relay_node

src/ir/transform.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ PassContext PassContext::Create() {
8484
return PassContext(make_object<PassContextNode>());
8585
}
8686

87+
void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
88+
auto pass_ctx_node = this->operator->();
89+
if (pass_ctx_node->trace_func != nullptr) {
90+
pass_ctx_node->trace_func(module, info, is_before);
91+
}
92+
}
93+
8794
class ModulePass;
8895

8996
/*!
@@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
231238
<< " with opt level: "
232239
<< pass_info->opt_level;
233240
CHECK(mod.defined());
241+
pass_ctx.Trace(mod, pass_info, true);
234242
IRModule updated_mod = pass_func(mod, pass_ctx);
235243
CHECK(updated_mod.defined());
244+
pass_ctx.Trace(updated_mod, pass_info, false);
236245
return updated_mod;
237246
}
238247

@@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
414423
int fallback_device = args[1];
415424
tvm::Array<tvm::PrimExpr> required = args[2];
416425
tvm::Array<tvm::PrimExpr> disabled = args[3];
426+
TraceFunc trace_func = args[4];
417427
pctx->opt_level = opt_level;
418428
pctx->fallback_device = fallback_device;
419429
pctx->required_pass = std::move(required);
420430
pctx->disabled_pass = std::move(disabled);
431+
pctx->trace_func = std::move(trace_func);
421432
*ret = pctx;
422433
});
423434

src/relay/ir/transform.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
116116
<< pass_info->name
117117
<< " with opt level: "
118118
<< pass_info->opt_level;
119-
119+
pass_ctx.Trace(mod, pass_info, true);
120120
// Execute the pass function and return a new module.
121121
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
122122
std::vector<std::pair<GlobalVar, Function> > updates;
@@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
134134
for (const auto& pair : updates) {
135135
updated_mod->Add(pair.first, pair.second, true);
136136
}
137+
pass_ctx.Trace(updated_mod, pass_info, false);
137138
return updated_mod;
138139
}
139140

tests/python/relay/test_pass_manager.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,36 @@ def test_print_ir(capfd):
522522
assert "Dumping the module IR" in out
523523
assert "multiply" in out
524524

525+
__TRACE_COUNTER__ = 0
526+
527+
def _tracer(module, info, is_before):
528+
global __TRACE_COUNTER__
529+
if bool(is_before):
530+
__TRACE_COUNTER__ += 1
531+
532+
def test_print_debug_callback():
533+
global __TRACE_COUNTER__
534+
shape = (1, 2, 3)
535+
tp = relay.TensorType(shape, "float32")
536+
x = relay.var("x", tp)
537+
y = relay.add(x, x)
538+
y = relay.multiply(y, relay.const(2, "float32"))
539+
func = relay.Function([x], y)
540+
541+
seq = _transform.Sequential([
542+
relay.transform.InferType(),
543+
relay.transform.FoldConstant(),
544+
relay.transform.DeadCodeElimination()
545+
])
546+
547+
assert __TRACE_COUNTER__ == 0
548+
mod = relay.Module({"main": func})
549+
550+
with relay.build_config(opt_level=3, trace=_tracer):
551+
mod = seq(mod)
552+
553+
assert __TRACE_COUNTER__ == 4
554+
525555

526556
if __name__ == "__main__":
527557
pytest.main()

0 commit comments

Comments
 (0)