Skip to content

Commit 159e739

Browse files
committed
[PASS] PrecomputePrune, add testcase (apache#14)
* [PASS] PrecomputePrune, add testcase * update comment
1 parent 3904d0b commit 159e739

File tree

14 files changed

+312
-50
lines changed

14 files changed

+312
-50
lines changed

nnvm/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ifneq ($(ADD_CFLAGS), NONE)
3030
endif
3131

3232
ifneq ($(ADD_LDFLAGS), NONE)
33-
LFFLAGS += $(ADD_LDFLAGS)
33+
LDFLAGS += $(ADD_LDFLAGS)
3434
endif
3535

3636
# plugin
@@ -46,6 +46,7 @@ ifeq ($(UNAME_S), Darwin)
4646
SHARED_LIBRARY_SUFFIX := dylib
4747
WHOLE_ARCH= -all_load
4848
NO_WHOLE_ARCH= -noall_load
49+
LDFLAGS += -undefined dynamic_lookup
4950
else
5051
SHARED_LIBRARY_SUFFIX := so
5152
WHOLE_ARCH= --whole-archive

nnvm/python/nnvm/compiler/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import tvm
55

66
from . import build_module
7-
from . build_module import build
7+
from . build_module import build, precompute_prune, _run_graph
88

99
from .. import symbol as _symbol
1010
from .. import graph as _graph

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from __future__ import absolute_import as _abs
44

55
import tvm
6-
from . import graph_attr
6+
from . import graph_attr, graph_pass
77
from .. import graph as _graph
8+
from .. import runtime
89

910
@tvm.register_func("nnvm.compiler.lower")
1011
def _lower(sch, inputs, func_name):
@@ -18,9 +19,6 @@ def _build(funcs, target):
1819
return tvm.build(funcs, target=target)
1920

2021

21-
_move_module = tvm.get_global_func("nnvm.compiler._move_module")
22-
23-
2422
def optimize(graph):
2523
"""Perform graph optimization
2624
@@ -70,10 +68,83 @@ def build(graph, target, shape, dtype="float32"):
7068
raise TypeError("require shape to be dict")
7169

7270
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
73-
graph = graph_attr.set_shape(graph, shape)
74-
graph = graph_attr.set_dtype(graph, dtype)
71+
graph = graph_attr.set_shape_inputs(graph, shape)
72+
graph = graph_attr.set_dtype_inputs(graph, dtype)
7573
graph._set_json_attr("target", target, "str")
7674
graph = graph.apply("InferShape").apply("InferType")
7775
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
78-
libmod = _move_module(graph)
76+
libmod = graph_attr._move_out_module(graph, "module")
7977
return graph, libmod
78+
79+
80+
def _run_graph(graph, params):
    """Helper utility to build and run a graph and fetch its outputs.

    Only uses CPU ("llvm" target) mode.

    Parameters
    ----------
    graph : Graph
        The graph to be executed.

    params : dict of str to ndarray
        The parameter dictionary; every graph input must be covered.

    Returns
    -------
    out_data : list of tvm.NDArray
        The output arrays, in graph output order.
    """
    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
    # Derive input shapes/dtypes from the supplied parameter arrays.
    shape = {k: v.shape for k, v in params.items()}
    dtype = {k: v.dtype for k, v in params.items()}
    target = "llvm"
    ctx = tvm.cpu(0)
    # Infer output shapes/dtypes up front so we can allocate result buffers.
    _, oshape = graph_pass.infer_shape(graph, **shape)
    _, odtype = graph_pass.infer_dtype(graph, **dtype)
    graph, libmod = build(graph, target, shape, dtype)
    m = runtime.create(graph, libmod, ctx)
    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
    for k, v in params.items():
        set_input(k, tvm.nd.array(v))
    run()
    out_data = []
    # Unpack (shape, dtype) pairs directly in the loop header instead of
    # a manual `kv` unpack, and avoid shadowing the input shape/dtype dicts.
    for i, (out_shape, out_dtype) in enumerate(zip(oshape, odtype)):
        arr = tvm.nd.empty(out_shape, out_dtype, ctx)
        get_output(i, arr)
        out_data.append(arr)
    return out_data
116+
117+
118+
def precompute_prune(graph, params):
    """Fold the statically-computable part of a graph into its parameters.

    Runs the "PrecomputePrune" pass, which splits the graph into a
    precompute sub-graph (depending only on parameters) and a pruned
    runtime graph. The precompute sub-graph is evaluated once on CPU
    and its results become new parameters.

    Parameters
    ----------
    graph : Graph
        The input graph

    params : dict of str -> tvm.NDArray
        The parameter dictionary of the graph

    Returns
    -------
    pruned_graph : Graph
        The pruned graph

    new_params : dict of str-> tvm.NDArray
        The updated dictionary of parameters.
    """
    if not isinstance(graph, _graph.Graph):
        graph = _graph.create(graph)
    graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
    graph = graph.apply("PrecomputePrune")
    folded = graph_attr._move_out_graph(graph, "precompute_graph")
    # Nothing could be precomputed: hand back graph and params untouched.
    if not folded.symbol.list_output_names():
        return graph, params
    names = folded.json_attr("output_names")
    arrays = _run_graph(folded, params)
    return graph, dict(zip(names, arrays))
Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
# pylint: disable=invalid-name
12
"""Utilities to access graph attributes"""
23
from __future__ import absolute_import as _abs
34

4-
def set_shape(g, shape):
5-
"""Set the shape of graph nodes in the graph attribute.
5+
import tvm
6+
7+
def set_shape_inputs(g, shape):
8+
"""Set the shape of input graph nodes in the graph attribute.
69
710
Parameters
811
----------
@@ -17,20 +20,24 @@ def set_shape(g, shape):
1720
g : Graph
1821
The updated graph with updated shape.
1922
"""
20-
index = g.index
21-
list_shape = [[]] * index.num_node_entries
22-
for k, v in shape.items():
23-
list_shape[index.entry_id(k)] = v
24-
g._set_json_attr("shape", list_shape, 'list_shape')
23+
list_shape = [
24+
shape.get(name, ()) for name in g.index.input_names]
25+
g._set_json_attr("shape_inputs", list_shape, 'list_shape')
2526
return g
2627

2728

28-
DTYPE_DICT = {
29+
DTYPE_TO_TCODE = {
30+
"default": -1,
2931
"float32": 0
3032
}
3133

32-
def set_dtype(g, dtype):
33-
"""Set the dtype of graph nodes
34+
TCODE_TO_DTYPE = {
35+
-1: None,
36+
0: "float32"
37+
}
38+
39+
def set_dtype_inputs(g, dtype):
40+
"""Set the dtype inputs of graph nodes
3441
3542
Parameters
3643
----------
@@ -45,12 +52,37 @@ def set_dtype(g, dtype):
4552
g : Graph
4653
The updated graph with updated dtype.
4754
"""
48-
index = g.index
4955
if isinstance(dtype, dict):
50-
list_dtype = [-1] * index.num_node_entries
51-
for k, v in dtype.items():
52-
list_dtype[index.entry_id(k)] = DTYPE_DICT[v]
56+
list_dtype = [
57+
DTYPE_TO_TCODE[dtype.get(name, "default")]
58+
for name in g.index.input_names]
5359
else:
54-
list_dtype = [DTYPE_DICT[dtype]] * index.num_node_entries
55-
g._set_json_attr("dtype", list_dtype, "list_int")
60+
list_dtype = [DTYPE_TO_TCODE[dtype]] * len(g.index.input_names)
61+
g._set_json_attr("dtype_inputs", list_dtype, "list_int")
62+
return g
63+
64+
65+
def set_layout_inputs(g, layout):
    """Set the layout of input graph nodes in the graph attribute.

    Parameters
    ----------
    g : Graph
        The input graph

    layout : dict of str to str or str
        The input layout; a dict keyed by input name, or a single
        layout string applied to every input.

    Returns
    -------
    g : Graph
        The updated graph with updated layout.
    """
    # Handle both dict and plain-string forms, mirroring set_dtype_inputs;
    # the original only supported dicts despite documenting both.
    if isinstance(layout, dict):
        # Inputs without an explicit layout fall back to "default".
        list_layout = [
            layout.get(name, "default") for name in g.index.input_names]
    else:
        list_layout = [layout] * len(g.index.input_names)
    g._set_json_attr("layout_inputs", list_layout, 'list_str')
    return g
85+
86+
87+
_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
88+
_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,62 @@
1+
# pylint: disable=invalid-name
12
"""Namespace of graph pass.
23
34
Principle:
45
- Graph in, graph out: always takes in graph as first argument and returns a graph
56
- Composable API: break graph transformation pass as segments of small transformations.
67
"""
78
from __future__ import absolute_import as _abs
9+
10+
from . import graph_attr
11+
12+
13+
def infer_shape(graph, **shape):
    """Infer the shape given the shape of inputs.

    Parameters
    ----------
    graph : Graph
        The graph to perform shape inference from

    shape : keyword arguments
        Shapes of the graph inputs, keyed by input name.

    Returns
    -------
    in_shape : list of tuple
        Shape of inputs

    out_shape: list of tuple
        Shape of outputs
    """
    graph = graph_attr.set_shape_inputs(graph, shape)
    graph = graph.apply("InferShape")
    inferred = graph.json_attr("shape")
    idx = graph.index

    def _collect(entries):
        # Map each input name / output entry to its inferred shape.
        return [inferred[idx.entry_id(e)] for e in entries]

    return _collect(idx.input_names), _collect(idx.output_entries)
36+
37+
38+
def infer_dtype(graph, **dtype):
    """Infer the dtype given the dtypes of inputs.

    Parameters
    ----------
    graph : Graph
        The graph to perform type inference from

    dtype : keyword arguments
        Dtypes of the graph inputs, keyed by input name.

    Returns
    -------
    in_dtype : list of str
        Dtype of inputs

    out_dtype : list of str
        Dtype of outputs
    """
    graph = graph_attr.set_dtype_inputs(graph, dtype)
    graph = graph.apply("InferType")
    # The pass records integer type codes; keep them in a separately-named
    # local instead of shadowing the `dtype` kwargs dict.
    tcodes = graph.json_attr("dtype")
    index = graph.index
    # Translate type codes back to dtype strings via the shared table.
    input_dtype = [graph_attr.TCODE_TO_DTYPE[tcodes[index.entry_id(x)]]
                   for x in index.input_names]
    output_dtype = [graph_attr.TCODE_TO_DTYPE[tcodes[index.entry_id(x)]]
                    for x in index.output_entries]
    return input_dtype, output_dtype

nnvm/python/nnvm/graph.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self, graph):
2424
self.nodes = jgraph["nodes"]
2525
self.entry_ptr = jgraph["node_row_ptr"]
2626
self._name2nodeid = {n["name"]: i for i, n in enumerate(self.nodes)}
27+
self.input_names = graph.symbol.list_input_names()
28+
self.output_entries = jgraph["heads"]
2729

2830
@property
2931
def num_nodes(self):
@@ -66,6 +68,10 @@ def entry_id(self, key, value_index=0):
6668
index : int
6769
The entry index
6870
"""
71+
if isinstance(key, (list, tuple)):
72+
if len(key) != 3:
73+
raise ValueError("Expect entry index to be tuple of 3 elems")
74+
key, value_index, _ = key
6975
idx = self.node_id(key) if isinstance(key, str) else key
7076
assert value_index < self.entry_ptr[idx + 1]
7177
return self.entry_ptr[idx] + value_index

nnvm/python/nnvm/top/attr_dict.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ def get_int(self, key):
6868
"""
6969
return int(self[key])
7070

71+
def get_float(self, key):
    """Look up *key* in the attr dict and coerce the value to float.

    Parameters
    ----------
    key : str
        The attr key

    Returns
    -------
    value : float
        The result value
    """
    raw = self[key]
    return float(raw)
85+
7186
def get_bool(self, key):
7287
"""Get bool from attr dict
7388

nnvm/python/nnvm/top/tensor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ def _schedule_broadcast(_, outs, target):
1717
tvm.schedule.AutoInlineInjective(s)
1818
return s
1919

20+
def _compute_binary_scalar(f):
    """Wrap binary *f* into an elementwise tensor-vs-scalar compute rule."""
    @tvm.tag_scope("ewise")
    def _compute(attrs, inputs):
        tensor = inputs[0]
        # Read the scalar operand from the op attributes and materialize
        # it as a constant of the tensor's dtype.
        const_val = tvm.const(attrs.get_float("scalar"), tensor.dtype)
        return tvm.compute(tensor.shape,
                           lambda *idx: f(tensor(*idx), const_val))
    return _compute
29+
30+
2031
_fschedule_broadcast = tvm.convert(_schedule_broadcast)
2132

2233
# exp
@@ -25,6 +36,12 @@ def _schedule_broadcast(_, outs, target):
2536
reg.register_pattern("exp", OpPattern.ELEM_WISE)
2637
reg.register_schedule("exp", _fschedule_broadcast)
2738

39+
# add scalar: elementwise tensor-plus-scalar op, built from the generic
# binary-scalar compute wrapper; scheduled like the other broadcast ops.
reg.register_compute("__add_scalar__",
                     _compute_binary_scalar(lambda x, y: x + y))
reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__add_scalar__", _fschedule_broadcast)
44+
2845
# broadcast_add
2946
reg.register_compute("broadcast_add",
3047
lambda _, x: topi.broadcast_add(x[0], x[1]))

nnvm/src/compiler/packed_func_ext.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,19 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
104104
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
105105
op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
106106
});
107+
108+
// Move a tvm::runtime::Module attribute out of a graph.
// args[0]: the graph; args[1]: the attribute key to move out.
TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    const nnvm::Graph& g = args[0].AsExtension<Graph>();
    // MoveCopyAttr takes the attribute by value; const_cast is needed
    // because AsExtension yields a const reference.
    *rv = const_cast<nnvm::Graph*>(&g)->
        MoveCopyAttr<tvm::runtime::Module>(args[1]);
  });

// Move a nested nnvm::Graph attribute out of a graph.
// args[0]: the graph; args[1]: the attribute key to move out.
TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    const nnvm::Graph& g = args[0].AsExtension<Graph>();
    *rv = const_cast<nnvm::Graph*>(&g)->
        MoveCopyAttr<nnvm::Graph>(args[1]);
  });
107121
} // namespace compiler
108122
} // namespace nnvm

nnvm/src/compiler/pass/graph_fuse.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -381,13 +381,5 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
381381

382382
NNVM_REGISTER_PASS(GraphFuse)
383383
.set_body(GraphFuse);
384-
385-
386-
TVM_REGISTER_GLOBAL("nnvm.compiler._move_module")
387-
.set_body([](TVMArgs args, TVMRetValue *rv) {
388-
const nnvm::Graph& g = args[0].AsExtension<Graph>();
389-
*rv = const_cast<nnvm::Graph*>(&g)->
390-
MoveCopyAttr<tvm::runtime::Module>("module");
391-
});
392384
} // namespace compiler
393385
} // namespace nnvm

0 commit comments

Comments
 (0)