Skip to content

Commit 9ff74fb

Browse files
authored
[TVMC] Add tvmc flag to print compilation time per pass (#15349)
Added a new flag `--print-pass-times` for tvmc compile to provide debugging information for tvmc users using `PassTimingInstrument`. Also added a test to check the printing of timing results.
1 parent 236eb31 commit 9ff74fb

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

python/tvm/driver/tvmc/compiler.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from tvm import autotvm, auto_scheduler
3232
from tvm import relay
3333
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
34-
from tvm.ir.instrument import PassInstrument
34+
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
3535
from tvm.ir.memory_pools import WorkspaceMemoryPools
3636
from tvm.target import Target
3737
from tvm.relay.backend import Executor, Runtime
@@ -157,6 +157,11 @@ def add_compile_parser(subparsers, _, json_params):
157157
default="default",
158158
help="The output module name. Defaults to 'default'.",
159159
)
160+
parser.add_argument(
161+
"--print-pass-times",
162+
action="store_true",
163+
help="print compilation time per pass",
164+
)
160165
for one_entry in json_params:
161166
parser.set_defaults(**one_entry)
162167

@@ -214,6 +219,7 @@ def drive_compile(args):
214219
workspace_pools=(
215220
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
216221
),
222+
print_pass_times=args.print_pass_times,
217223
**transform_args,
218224
)
219225

@@ -240,6 +246,7 @@ def compile_model(
240246
use_vm: bool = False,
241247
mod_name: Optional[str] = "default",
242248
workspace_pools: Optional[WorkspaceMemoryPools] = None,
249+
print_pass_times: bool = False,
243250
instruments: Optional[Sequence[PassInstrument]] = None,
244251
desired_layout: Optional[str] = None,
245252
desired_layout_ops: Optional[List[str]] = None,
@@ -301,6 +308,8 @@ def compile_model(
301308
workspace_pools: WorkspaceMemoryPools, optional
302309
Specification of WorkspacePoolInfo objects to be used as workspace memory in the
303310
compilation.
311+
print_pass_times: bool
312+
To enable printing a breakdown of compilation times by pass. Disabled by default.
304313
instruments: Optional[Sequence[PassInstrument]]
305314
The list of pass instrument implementations.
306315
desired_layout: str, optional
@@ -356,6 +365,10 @@ def compile_model(
356365
if codegen["config_key"] is not None:
357366
config[codegen["config_key"]] = codegen_from_cli["opts"]
358367

368+
if print_pass_times:
369+
timing_inst = PassTimingInstrument()
370+
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments
371+
359372
with tvm.transform.PassContext(
360373
opt_level=opt_level,
361374
config=config,
@@ -442,6 +455,11 @@ def compile_model(
442455
if dumps:
443456
save_dumps(package_path, dumps)
444457

458+
# Print compilation times per pass
459+
if print_pass_times:
460+
print("Compilation time breakdown by pass:")
461+
print(timing_inst.render())
462+
445463
return TVMCPackage(package_path)
446464

447465

tests/python/driver/tvmc/test_command_line.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,20 @@ def test_tvmc_logger_set_basicConfig(monkeypatch, tmpdir_factory, keras_simple):
272272
_main(compile_args)
273273

274274
mock_basicConfig.assert_called_with(stream=sys.stdout)
275+
276+
277+
def test_tvmc_print_pass_times(capsys, keras_simple, tmpdir_factory):
278+
pytest.importorskip("tensorflow")
279+
tmpdir = tmpdir_factory.mktemp("out")
280+
print_cmd = "--print-pass-times"
281+
282+
# Compile model
283+
module_file = os.path.join(tmpdir, "keras-tvm.tar")
284+
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
285+
compile_args = compile_cmd.split(" ")[1:]
286+
_main(compile_args)
287+
288+
# Check for timing results output
289+
captured_out = capsys.readouterr().out
290+
for exp_str in ("Compilation time breakdown by pass:", "sequential:", "us]"):
291+
assert exp_str in captured_out

0 commit comments

Comments
 (0)