Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tvm import autotvm, auto_scheduler
from tvm import relay
from tvm.driver.tvmc.registry import generate_registry_args, reconstruct_registry_entity
from tvm.ir.instrument import PassInstrument
from tvm.ir.instrument import PassInstrument, PassTimingInstrument
from tvm.ir.memory_pools import WorkspaceMemoryPools
from tvm.target import Target
from tvm.relay.backend import Executor, Runtime
Expand Down Expand Up @@ -157,6 +157,11 @@ def add_compile_parser(subparsers, _, json_params):
default="default",
help="The output module name. Defaults to 'default'.",
)
parser.add_argument(
"--print-pass-times",
action="store_true",
help="print compilation time per pass",
)
for one_entry in json_params:
parser.set_defaults(**one_entry)

Expand Down Expand Up @@ -214,6 +219,7 @@ def drive_compile(args):
workspace_pools=(
workspace_pools_recombobulate(args, [workspace_pools_target], extra_targets)
),
print_pass_times=args.print_pass_times,
**transform_args,
)

Expand All @@ -240,6 +246,7 @@ def compile_model(
use_vm: bool = False,
mod_name: Optional[str] = "default",
workspace_pools: Optional[WorkspaceMemoryPools] = None,
print_pass_times: bool = False,
instruments: Optional[Sequence[PassInstrument]] = None,
desired_layout: Optional[str] = None,
desired_layout_ops: Optional[List[str]] = None,
Expand Down Expand Up @@ -301,6 +308,8 @@ def compile_model(
workspace_pools: WorkspaceMemoryPools, optional
Specification of WorkspacePoolInfo objects to be used as workspace memory in the
compilation.
print_pass_times: bool
To enable printing a breakdown of compilation times by pass. Disabled by default.
instruments: Optional[Sequence[PassInstrument]]
The list of pass instrument implementations.
desired_layout: str, optional
Expand Down Expand Up @@ -356,6 +365,10 @@ def compile_model(
if codegen["config_key"] is not None:
config[codegen["config_key"]] = codegen_from_cli["opts"]

if print_pass_times:
timing_inst = PassTimingInstrument()
instruments = [timing_inst] if instruments is None else [timing_inst] + instruments

with tvm.transform.PassContext(
opt_level=opt_level,
config=config,
Expand Down Expand Up @@ -442,6 +455,11 @@ def compile_model(
if dumps:
save_dumps(package_path, dumps)

# Print compilation times per pass
if print_pass_times:
print("Compilation time breakdown by pass:")
print(timing_inst.render())

return TVMCPackage(package_path)


Expand Down
17 changes: 17 additions & 0 deletions tests/python/driver/tvmc/test_command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,20 @@ def test_tvmc_logger_set_basicConfig(monkeypatch, tmpdir_factory, keras_simple):
_main(compile_args)

mock_basicConfig.assert_called_with(stream=sys.stdout)


def test_tvmc_print_pass_times(capsys, keras_simple, tmpdir_factory):
pytest.importorskip("tensorflow")
tmpdir = tmpdir_factory.mktemp("out")
print_cmd = "--print-pass-times"

# Compile model
module_file = os.path.join(tmpdir, "keras-tvm.tar")
compile_cmd = f"tvmc compile --target 'llvm' {keras_simple} --output {module_file} {print_cmd}"
compile_args = compile_cmd.split(" ")[1:]
_main(compile_args)

# Check for timing results output
captured_out = capsys.readouterr().out
for exp_str in ("Compilation time breakdown by pass:", "sequential:", "us]"):
assert exp_str in captured_out