3131from tvm import autotvm , auto_scheduler
3232from tvm import relay
3333from tvm .driver .tvmc .registry import generate_registry_args , reconstruct_registry_entity
34- from tvm .ir .instrument import PassInstrument
34+ from tvm .ir .instrument import PassInstrument , PassTimingInstrument
3535from tvm .ir .memory_pools import WorkspaceMemoryPools
3636from tvm .target import Target
3737from tvm .relay .backend import Executor , Runtime
@@ -157,6 +157,11 @@ def add_compile_parser(subparsers, _, json_params):
157157 default = "default" ,
158158 help = "The output module name. Defaults to 'default'." ,
159159 )
160+ parser .add_argument (
161+ "--print-pass-times" ,
162+ action = "store_true" ,
163+ help = "print compilation time per pass" ,
164+ )
160165 for one_entry in json_params :
161166 parser .set_defaults (** one_entry )
162167
@@ -214,6 +219,7 @@ def drive_compile(args):
214219 workspace_pools = (
215220 workspace_pools_recombobulate (args , [workspace_pools_target ], extra_targets )
216221 ),
222+ print_pass_times = args .print_pass_times ,
217223 ** transform_args ,
218224 )
219225
@@ -240,6 +246,7 @@ def compile_model(
240246 use_vm : bool = False ,
241247 mod_name : Optional [str ] = "default" ,
242248 workspace_pools : Optional [WorkspaceMemoryPools ] = None ,
249+ print_pass_times : bool = False ,
243250 instruments : Optional [Sequence [PassInstrument ]] = None ,
244251 desired_layout : Optional [str ] = None ,
245252 desired_layout_ops : Optional [List [str ]] = None ,
@@ -301,6 +308,8 @@ def compile_model(
301308 workspace_pools: WorkspaceMemoryPools, optional
302309 Specification of WorkspacePoolInfo objects to be used as workspace memory in the
303310 compilation.
311+ print_pass_times: bool
312+ To enable printing a breakdown of compilation times by pass. Disabled by default.
304313 instruments: Optional[Sequence[PassInstrument]]
305314 The list of pass instrument implementations.
306315 desired_layout: str, optional
@@ -356,6 +365,10 @@ def compile_model(
356365 if codegen ["config_key" ] is not None :
357366 config [codegen ["config_key" ]] = codegen_from_cli ["opts" ]
358367
368+ if print_pass_times :
369+ timing_inst = PassTimingInstrument ()
370+ instruments = [timing_inst ] if instruments is None else [timing_inst ] + instruments
371+
359372 with tvm .transform .PassContext (
360373 opt_level = opt_level ,
361374 config = config ,
@@ -442,6 +455,11 @@ def compile_model(
442455 if dumps :
443456 save_dumps (package_path , dumps )
444457
458+ # Print compilation times per pass
459+ if print_pass_times :
460+ print ("Compilation time breakdown by pass:" )
461+ print (timing_inst .render ())
462+
445463 return TVMCPackage (package_path )
446464
447465
0 commit comments