
Commit b37e5c2

[RUNTIME][COMPILER] Formal compiler pipeline, runtime wrapper module (#21)
* [RUNTIME][COMPILER] Formal compiler pipeline, runtime wrapper module
* more detailed comments
1 parent ddd23a8 commit b37e5c2

File tree: 13 files changed (+340, -130 lines)

nnvm/python/nnvm/compiler/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import tvm
 
 from . import build_module
-from . build_module import build, precompute_prune, _run_graph
+from . build_module import build, optimize, build_config
 
 from .. import symbol as _symbol
 from .. import graph as _graph
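
For orientation, a hedged sketch of the resulting import surface (hypothetical user code, not part of the commit):

    # The package now re-exports the compiler entry points shown above.
    from nnvm.compiler import build, optimize, build_config
    # build_config opens a configuration scope; build compiles a graph inside
    # (or outside) such a scope. See build_module.py below for both.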

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 115 additions & 11 deletions
@@ -3,10 +3,74 @@
 from __future__ import absolute_import as _abs
 
 import tvm
-from . import graph_attr, graph_pass
+from . import graph_attr, graph_util
 from .. import graph as _graph
 from .. import runtime
 
+# Optimization passes and the opt_level at which each is switched on.
+OPT_PASS_LEVEL = {
+    "SimplifyBatchNormInference": 2,
+    "PrecomputePrune": 2,
+    "OpFusion": 1
+}
+
+class BuildConfig(object):
+    """Configuration scope to set a build config option.
+
+    Parameters
+    ----------
+    kwargs
+        Keyword arguments of configurations to set.
+    """
+    current = None
+    defaults = {
+        "opt_level": 2,
+    }
+    def __init__(self, **kwargs):
+        self._old_scope = None
+        for k, _ in kwargs.items():
+            if k not in BuildConfig.defaults:
+                raise ValueError(
+                    "invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
+        self._attr = kwargs
+
+    def __getattr__(self, name):
+        if name not in self._attr:
+            return BuildConfig.defaults[name]
+        return self._attr[name]
+
+    def __enter__(self):
+        # pylint: disable=protected-access
+        self._old_scope = BuildConfig.current
+        attr = BuildConfig.current._attr.copy()
+        attr.update(self._attr)
+        self._attr = attr
+        BuildConfig.current = self
+        return self
+
+    def __exit__(self, ptype, value, trace):
+        assert self._old_scope
+        BuildConfig.current = self._old_scope
+
+
+BuildConfig.current = BuildConfig()
+
+def build_config(**kwargs):
+    """Configure the build behavior by setting config variables.
+
+    Parameters
+    ----------
+    opt_level: int, default=2
+        Optimization level. See OPT_PASS_LEVEL for the level of each pass.
+
+    Returns
+    -------
+    config: BuildConfig
+        The build configuration
+    """
+    return BuildConfig(**kwargs)
+
+
 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
     f = tvm.lower(sch, inputs, name=func_name)
@@ -19,23 +83,45 @@ def _build(funcs, target):
     return tvm.build(funcs, target=target)
 
 
-def optimize(graph):
-    """Perform graph optimization
+def _update_shape_dtype(shape, dtype, params):
+    """Update shape and dtype given params information"""
+    if not params:
+        return shape, dtype
+    shape = shape.copy()
+    shape.update({k : v.shape for k, v in params.items()})
+    if isinstance(dtype, str):
+        for k, v in params.items():
+            if v.dtype != dtype:
+                raise ValueError(
+                    "%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
+    else:
+        dtype = dtype.copy()
+        dtype.update({k : str(v.dtype) for k, v in params.items()})
+    return shape, dtype
+
+
+def optimize(graph, shape, dtype="float32"):
+    """Perform target and parameter invariant graph optimization.
 
     Parameters
     ----------
     graph : Graph
-        The graph to be used in lowering.
+        The graph to be optimized.
 
     Returns
     -------
     graph : Graph
-        The optimized execution graph.
+        The optimized graph.
     """
+    # pylint: disable=unused-argument
+    cfg = BuildConfig.current
+    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
+        graph = graph_attr.set_shape_inputs(graph, shape)
+        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
     return graph
 
 
-def build(graph, target, shape, dtype="float32"):
+def build(graph, target, shape, dtype="float32", params=None):
     """Build graph into runtime library.
 
     This is the final step of graph compilation.
@@ -54,27 +140,45 @@ def build(graph, target, shape, dtype="float32"):
     dtype : str or dict of str to str
         The input types to the graph
 
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for pre-compute
+        folding optimization.
+
     Returns
     -------
     graph : Graph
         The final execution graph.
 
     libmod : tvm.Module
         The module that comes with the execution graph
+
+    params : dict of str to NDArray
+        The updated parameters of graph if params is passed.
+        This can be different from the params passed in.
     """
     if not isinstance(target, str):
         raise TypeError("require target to be str")
     if not isinstance(shape, dict):
         raise TypeError("require shape to be dict")
-
+    cfg = BuildConfig.current
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Apply optimization
+    graph = optimize(graph, shape, dtype)
+    # Precompute prune
+    if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
+        graph, params = precompute_prune(graph, params)
+        shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Operator fusion and code generation
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
+    graph._set_json_attr("opt_level", cfg.opt_level, "int")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuse")
     libmod = graph_attr._move_out_module(graph, "module")
-    return graph, libmod
+    return graph, libmod, params
 
 
 def _run_graph(graph, params):
@@ -98,9 +202,9 @@ def _run_graph(graph, params):
     dtype = {k : v.dtype for k, v in params.items()}
     target = "llvm"
     ctx = tvm.cpu(0)
-    _, oshape = graph_pass.infer_shape(graph, **shape)
-    _, odtype = graph_pass.infer_dtype(graph, **dtype)
-    graph, libmod = build(graph, target, shape, dtype)
+    _, oshape = graph_util.infer_shape(graph, **shape)
+    _, odtype = graph_util.infer_dtype(graph, **dtype)
+    graph, libmod, _ = build(graph, target, shape, dtype)
     m = runtime.create(graph, libmod, ctx)
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
     for k, v in params.items():
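
Taken together, the new pipeline reads as: update shapes and dtypes from params, run target-invariant optimization, optionally precompute-prune, then fuse and generate code. A hedged end-to-end sketch follows; the exp network, the shapes, and the module-call style are illustrative (they mirror the _run_graph helper above), not a verbatim API guarantee:

    # Hypothetical driver for the new pipeline introduced in this commit.
    import numpy as np
    import tvm
    import nnvm.compiler
    import nnvm.symbol as sym
    from nnvm import runtime

    x = sym.Variable("x")
    net = sym.exp(x)  # illustrative one-op network

    with nnvm.compiler.build_config(opt_level=2):
        # build() now returns the (possibly pruned) params as a third value.
        graph, libmod, params = nnvm.compiler.build(
            net, "llvm", shape={"x": (2, 4)}, dtype="float32")

    # Run through the runtime wrapper module, as _run_graph does above.
    m = runtime.create(graph, libmod, tvm.cpu(0))
    m["set_input"]("x", tvm.nd.array(np.ones((2, 4), dtype="float32")))
    m["run"]()
    out = m["get_output"](0, tvm.nd.empty((2, 4)))

Note that build now returns three values; callers that previously unpacked two, like _run_graph, must be updated accordingly.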

nnvm/python/nnvm/compiler/graph_pass.py

Lines changed: 0 additions & 78 deletions
@@ -6,81 +6,3 @@
 - Composable API: break graph transformation pass as segments of small transformations.
 """
 from __future__ import absolute_import as _abs
-
-import tvm
-from . import graph_attr
-
-
-def infer_shape(graph, **shape):
-    """Infer the shape given the shape of inputs.
-
-    Parameters
-    ----------
-    graph : Graph
-        The graph to perform shape inference from
-
-    Returns
-    -------
-    in_shape : list of tuple
-        Shape of inputs
-
-    out_shape: list of tuple
-        Shape of outputs
-    """
-    graph = graph_attr.set_shape_inputs(graph, shape)
-    graph = graph.apply("InferShape")
-    shape = graph.json_attr("shape")
-    index = graph.index
-    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
-    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
-    return input_shape, output_shape
-
-
-def infer_dtype(graph, **dtype):
-    """Infer the type given the types of inputs.
-
-    Parameters
-    ----------
-    graph : Graph
-        The graph to perform type inference from
-
-    Returns
-    -------
-    in_dtype : list of tuple
-        Dtype of inputs
-
-    out_dtype: list of tuple
-        Dtype of outputs
-    """
-    graph = graph_attr.set_dtype_inputs(graph, dtype)
-    graph = graph.apply("InferType")
-    dtype = graph.json_attr("dtype")
-    index = graph.index
-    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
-                   for x in index.input_names]
-    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
-                    for x in index.output_entries]
-    return input_dtype, output_dtype
-
-
-_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
-
-def check_graph_equal(grapha, graphb):
-    """Check if two graphs have equal structure.
-
-    Parameters
-    ----------
-    grapha : Graph
-        The first graph
-
-    graphb : Graph
-        The second graph
-
-    Raises
-    ------
-    ValueError
-        ValueError is raised with error message when graphs are not equal
-    """
-    err = _deep_compare(grapha, graphb)
-    if err:
-        raise ValueError("Graph compare error: " + err)
nnvm/python/nnvm/compiler/graph_util.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+# pylint: disable=invalid-name
+"""Utility function to get information from graph."""
+from __future__ import absolute_import as _abs
+
+import tvm
+from . import graph_attr
+
+def infer_shape(graph, **shape):
+    """Infer the shape given the shape of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform shape inference from
+
+    Returns
+    -------
+    in_shape : list of tuple
+        Shape of inputs
+
+    out_shape: list of tuple
+        Shape of outputs
+    """
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    shape = graph.json_attr("shape")
+    index = graph.index
+    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
+    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
+    return input_shape, output_shape
+
+
+def infer_dtype(graph, **dtype):
+    """Infer the type given the types of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform type inference from
+
+    Returns
+    -------
+    in_dtype : list of tuple
+        Dtype of inputs
+
+    out_dtype: list of tuple
+        Dtype of outputs
+    """
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
+    graph = graph.apply("InferType")
+    dtype = graph.json_attr("dtype")
+    index = graph.index
+    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                   for x in index.input_names]
+    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                    for x in index.output_entries]
+    return input_dtype, output_dtype
+
+
+_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
+
+def check_graph_equal(grapha, graphb):
+    """Check if two graphs have equal structure.
+
+    Parameters
+    ----------
+    grapha : Graph
+        The first graph
+
+    graphb : Graph
+        The second graph
+
+    Raises
+    ------
+    ValueError
+        ValueError is raised with error message when graphs are not equal
+    """
+    err = _deep_compare(grapha, graphb)
+    if err:
+        raise ValueError("Graph compare error: " + err)
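
A small hedged sketch of calling the relocated helpers (the one-op exp graph is illustrative, not part of the commit):

    # Hypothetical check that the moved utilities behave as before the move.
    import nnvm.symbol as sym
    import nnvm.graph as _graph
    from nnvm.compiler import graph_util

    x = sym.Variable("x")
    g = _graph.create(sym.exp(x))         # illustrative one-op graph

    in_shape, out_shape = graph_util.infer_shape(g, x=(2, 4))
    assert tuple(out_shape[0]) == (2, 4)  # exp preserves shape

    # check_graph_equal raises ValueError when DeepCompare reports a diff.
    graph_util.check_graph_equal(g, _graph.create(sym.exp(x)))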
