33from __future__ import absolute_import as _abs
44
55import tvm
6- from . import graph_attr , graph_pass
6+ from . import graph_attr , graph_util
77from .. import graph as _graph
88from .. import runtime
99
# Maps each optimization pass to the minimum opt_level at which it is enabled.
OPT_PASS_LEVEL = {
    "SimplifyBatchNormInference": 2,
    "PrecomputePrune": 2,
    "OpFusion": 1,
}
15+
16+ # List of optimization passes and the opt_level at which each is switched on
class BuildConfig(object):
    """Configuration scope to set a build config option.

    Parameters
    ----------
    kwargs
        Keyword arguments of configurations to set.
    """
    # The innermost active configuration scope; initialized below and
    # swapped in/out by __enter__/__exit__.
    current = None
    # Default value for every recognized option.
    defaults = {
        "opt_level": 2,
    }

    def __init__(self, **kwargs):
        self._old_scope = None
        # Reject unknown option names early with the list of valid candidates.
        for k in kwargs:
            if k not in BuildConfig.defaults:
                raise ValueError(
                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
        self._attr = kwargs

    def __getattr__(self, name):
        # __getattr__ is only invoked when normal lookup fails.  Read _attr
        # through __dict__ so that an instance whose __init__ has not run
        # (e.g. created via __new__ or unpickling) falls back to the defaults
        # instead of recursing infinitely while looking up _attr itself.
        attr = self.__dict__.get("_attr")
        if attr and name in attr:
            return attr[name]
        return BuildConfig.defaults[name]

    def __enter__(self):
        # pylint: disable=protected-access
        # Merge the enclosing scope's options with this scope's overrides,
        # then install this scope as the current one.
        self._old_scope = BuildConfig.current
        attr = BuildConfig.current._attr.copy()
        attr.update(self._attr)
        self._attr = attr
        BuildConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope
        # Restore the enclosing scope on exit.
        BuildConfig.current = self._old_scope


# Root configuration scope carrying only the defaults.
BuildConfig.current = BuildConfig()
57+
def build_config(**kwargs):
    """Create a build configuration scope from config variables.

    Parameters
    ----------
    opt_level: int, default=2
        Optimization level. See OPT_PASS_LEVEL for level of each pass.

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    return BuildConfig(**kwargs)
72+
73+
1074@tvm .register_func ("nnvm.compiler.lower" )
1175def _lower (sch , inputs , func_name ):
1276 f = tvm .lower (sch , inputs , name = func_name )
@@ -19,23 +83,45 @@ def _build(funcs, target):
1983 return tvm .build (funcs , target = target )
2084
2185
22- def optimize (graph ):
23- """Perform graph optimization
86+ def _update_shape_dtype (shape , dtype , params ):
87+ """Update shape dtype given params information"""
88+ if not params :
89+ return shape , dtype
90+ shape = shape .copy ()
91+ shape .update ({k : v .shape for k , v in params .items ()})
92+ if isinstance (dtype , str ):
93+ for k , v in params .items ():
94+ if v .dtype != dtype :
95+ raise ValueError (
96+ "%s: dtype not expected %s vs %s" % (k , dtype , v .dtype ))
97+ else :
98+ dtype = dtype .copy ()
99+ dtype .update ({k : str (v .dtype ) for k , v in params .items ()})
100+ return shape , dtype
101+
102+
def optimize(graph, shape, dtype="float32"):
    """Perform target and parameter invariant graph optimization.

    Parameters
    ----------
    graph : Graph
        The graph to be optimized.

    shape : dict of str to tuple
        The input shapes to the graph, used for shape inference.

    dtype : str or dict of str to str, optional
        The input types to the graph; currently unused by the passes below.

    Returns
    -------
    graph : Graph
        The optimized graph.
    """
    # pylint: disable=unused-argument
    cfg = BuildConfig.current
    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
        # Shape information must be attached and inferred before the
        # batch-norm simplification pass can run.
        graph = graph_attr.set_shape_inputs(graph, shape)
        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
    return graph
36122
37123
38- def build (graph , target , shape , dtype = "float32" ):
124+ def build (graph , target , shape , dtype = "float32" , params = None ):
39125 """Build graph into runtime library.
40126
41127 This is the final step of graph compilation.
@@ -54,27 +140,45 @@ def build(graph, target, shape, dtype="float32"):
54140 dtype : str or dict of str to str
55141 The input types to the graph
56142
143+ params : dict of str to NDArray
144+ Input parameters to the graph that do not change
145+ during inference time. Used for pre-compute
146+ folding optimization.
147+
57148 Returns
58149 -------
59150 graph : Graph
60151 The final execution graph.
61152
62153 libmod : tvm.Module
63154 The module that comes with the execution graph
155+
156+ params : dict of str to NDArray
157+ The updated parameters of graph if params is passed.
158+ This can be different from the params passed in.
64159 """
65160 if not isinstance (target , str ):
66161 raise TypeError ("require target to be str" )
67162 if not isinstance (shape , dict ):
68163 raise TypeError ("require shape to be dict" )
69-
164+ cfg = BuildConfig . current
70165 graph = graph if isinstance (graph , _graph .Graph ) else _graph .create (graph )
166+ shape , dtype = _update_shape_dtype (shape , dtype , params )
167+ # Apply optimization
168+ graph = optimize (graph , shape , dtype )
169+ # Precompute prune
170+ if params and cfg .opt_level >= OPT_PASS_LEVEL ["PrecomputePrune" ]:
171+ graph , params = precompute_prune (graph , params )
172+ shape , dtype = _update_shape_dtype (shape , dtype , params )
173+ # Operator fusion and code generation
71174 graph = graph_attr .set_shape_inputs (graph , shape )
72175 graph = graph_attr .set_dtype_inputs (graph , dtype )
73176 graph ._set_json_attr ("target" , target , "str" )
177+ graph ._set_json_attr ("opt_level" , cfg .opt_level , "int" )
74178 graph = graph .apply ("InferShape" ).apply ("InferType" )
75179 graph = graph .apply ("GraphFusePartition" ).apply ("GraphFuse" )
76180 libmod = graph_attr ._move_out_module (graph , "module" )
77- return graph , libmod
181+ return graph , libmod , params
78182
79183
80184def _run_graph (graph , params ):
@@ -98,9 +202,9 @@ def _run_graph(graph, params):
98202 dtype = {k : v .dtype for k , v in params .items ()}
99203 target = "llvm"
100204 ctx = tvm .cpu (0 )
101- _ , oshape = graph_pass .infer_shape (graph , ** shape )
102- _ , odtype = graph_pass .infer_dtype (graph , ** dtype )
103- graph , libmod = build (graph , target , shape , dtype )
205+ _ , oshape = graph_util .infer_shape (graph , ** shape )
206+ _ , odtype = graph_util .infer_dtype (graph , ** dtype )
207+ graph , libmod , _ = build (graph , target , shape , dtype )
104208 m = runtime .create (graph , libmod , ctx )
105209 set_input , run , get_output = m ["set_input" ], m ["run" ], m ["get_output" ]
106210 for k , v in params .items ():
0 commit comments