33from __future__ import absolute_import as _abs
44
55import tvm
6- from . import graph_attr
6+ from . import graph_attr , graph_pass
77from .. import graph as _graph
8+ from .. import runtime
89
910@tvm .register_func ("nnvm.compiler.lower" )
1011def _lower (sch , inputs , func_name ):
@@ -18,9 +19,6 @@ def _build(funcs, target):
1819 return tvm .build (funcs , target = target )
1920
2021
21- _move_module = tvm .get_global_func ("nnvm.compiler._move_module" )
22-
23-
2422def optimize (graph ):
2523 """Perform graph optimization
2624
@@ -70,10 +68,83 @@ def build(graph, target, shape, dtype="float32"):
7068 raise TypeError ("require shape to be dict" )
7169
7270 graph = graph if isinstance (graph , _graph .Graph ) else _graph .create (graph )
73- graph = graph_attr .set_shape (graph , shape )
74- graph = graph_attr .set_dtype (graph , dtype )
71+ graph = graph_attr .set_shape_inputs (graph , shape )
72+ graph = graph_attr .set_dtype_inputs (graph , dtype )
7573 graph ._set_json_attr ("target" , target , "str" )
7674 graph = graph .apply ("InferShape" ).apply ("InferType" )
7775 graph = graph .apply ("GraphFusePartition" ).apply ("GraphFuse" )
78- libmod = _move_module (graph )
76+ libmod = graph_attr . _move_out_module (graph , "module" )
7977 return graph , libmod
78+
79+
def _run_graph(graph, params):
    """Helper utility: build a graph for CPU (llvm), run it, collect outputs.

    Only uses cpu mode; intended for internal precompute passes.

    Parameters
    ----------
    graph : Graph
        The graph to be executed.

    params : dict of str to ndarray
        The parameter dictionary; shapes/dtypes of the inputs are
        derived from these arrays.

    Returns
    -------
    out_data : list of tvm.NDArray
        The computed outputs, in graph output order.
    """
    if not isinstance(graph, _graph.Graph):
        graph = _graph.create(graph)
    # Derive input shapes/dtypes from the supplied parameter arrays.
    ishapes = {name: arr.shape for name, arr in params.items()}
    idtypes = {name: arr.dtype for name, arr in params.items()}
    target = "llvm"
    ctx = tvm.cpu(0)
    # Infer output shapes/dtypes up front, before build/fusion rewrites the graph.
    _, oshapes = graph_pass.infer_shape(graph, **ishapes)
    _, odtypes = graph_pass.infer_dtype(graph, **idtypes)
    graph, libmod = build(graph, target, ishapes, idtypes)
    module = runtime.create(graph, libmod, ctx)
    set_input = module["set_input"]
    run = module["run"]
    get_output = module["get_output"]
    for name, arr in params.items():
        set_input(name, tvm.nd.array(arr))
    run()
    out_data = []
    for idx, (oshape, odtype) in enumerate(zip(oshapes, odtypes)):
        out_arr = tvm.nd.empty(oshape, odtype, ctx)
        get_output(idx, out_arr)
        out_data.append(out_arr)
    return out_data
116+
117+
def precompute_prune(graph, params):
    """Precompute the part of graph that can be pre-computed.

    Splits out the subgraph whose values depend only on ``params``,
    evaluates it once, and returns a new graph that only contains the
    ops that need to be computed at runtime, together with an updated
    parameter dict holding the precomputed intermediate results.

    Parameters
    ----------
    graph : Graph
        The input graph

    params : dict of str -> tvm.NDArray
        The parameter dictionary of the graph

    Returns
    -------
    pruned_graph : Graph
        The pruned graph

    new_params : dict of str-> tvm.NDArray
        The updated dictionary of parameters.
    """
    if not isinstance(graph, _graph.Graph):
        graph = _graph.create(graph)
    graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
    graph = graph.apply("PrecomputePrune")
    pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
    if not pre_graph.symbol.list_output_names():
        # Nothing was split out for precomputation; return inputs untouched.
        return graph, params
    out_names = pre_graph.json_attr("output_names")
    out_arrs = _run_graph(pre_graph, params)
    return graph, dict(zip(out_names, out_arrs))
0 commit comments