Commit f21b5ca

[DOCS] Add save_param_dict, readme (apache#42)
1 parent b7b0061
File tree: 12 files changed, +412 −49 lines changed

nnvm/README.md

Lines changed: 42 additions & 4 deletions

@@ -3,16 +3,54 @@
 [![Build Status](https://travis-ci.org/dmlc/nnvm.svg?branch=master)](https://travis-ci.org/dmlc/nnvm)
 [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)
 
-NNVM is a reusable computational graph optimization and compilation stack for deep learning systems.
-NNVM provides modules to:
+NNVM is a reusable computational graph optimization and compilation stack for deep learning systems. It provides modules to:
 
 - Represent deep learning workloads from front-end frameworks via a graph IR.
 - Optimize computation graphs to improve performance.
 - Compile into executable modules and deploy to different hardware backends with minimum dependency.
 
-NNVM is designed to add new frontend, operators and graph optimizations in a decentralized fashion without changing the core interface. NNVM is part of [TVM stack](https://github.com/dmlc/tvm), which provides an end to end IR compilation stack for deploying deep learning workloads into different hardware backends
+NNVM is designed to allow new frontends, operators and graph optimizations to be added in a decentralized fashion without changing the core interface. NNVM is part of the [TVM stack](https://github.com/dmlc/tvm); the NNVM compiler toolchain can target any hardware backend supported by TVM.
+The compiled module can be deployed to servers, mobile and embedded devices, and browsers with minimal dependencies, in languages including C++, Python, JavaScript, Java and Objective-C.
+
+The following code snippet demonstrates the general workflow of the nnvm compiler toolchain.
+
+```python
+import tvm
+from tvm.contrib import graph_runtime, rpc
+import nnvm.frontend
+import nnvm.compiler
+
+# get the model from a front-end framework
+# change xyz to a supported framework name.
+graph, params = nnvm.frontend.from_xyz(...)
+
+# optimize and compile the graph to get a deployable module
+# target can be "opencl", "llvm", "metal" or any target supported by tvm
+target = "cuda"
+graph, lib, params = nnvm.compiler.build(
+    graph, target, shape={"data": data_shape}, params=params)
+
+# deploy and run on gpu(0)
+module = graph_runtime.create(graph, lib, tvm.gpu(0))
+module.set_input(**params)
+output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
+for data_array in dataset:
+    module.set_input("data", data_array)
+    module.run()
+    module.get_output(0, output)
+
+# deploy to a remote mobile/rasp/browser target with the minimum tvm rpc runtime
+# useful for quick experiments on mobile devices
+remote = rpc.connect(remote_host, remote_port)
+lib.export_library("mylib.so")
+remote.upload("mylib.so")
+rlib = rpc.load_module("mylib.so")
+# run on the remote device
+rmodule = graph_runtime.create(graph, rlib, remote.gpu(0))
+rmodule.set_input(**params)
+rmodule.run()
+```
 
 ## Links
 - [TinyFlow](https://github.com/tqchen/tinyflow) on how you can use NNVM to build a TensorFlow like API.
 - [Apache MXNet](http://mxnet.io/) uses NNVM as a backend.
-

nnvm/docs/api/python/compiler.rst

Lines changed: 4 additions & 0 deletions

@@ -7,6 +7,10 @@ nnvm.compiler
 
 .. autofunction:: nnvm.compiler.build_config
 
+.. autofunction:: nnvm.compiler.save_param_dict
+
+.. autofunction:: nnvm.compiler.load_param_dict
+
 .. autofunction:: nnvm.compiler.optimize
 
 .. automodule:: nnvm.compiler.graph_util

nnvm/python/nnvm/compiler/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,7 @@
 """NNVM compiler toolchain.
 
-User only need to use :any:`build` and :any:`build_config` to do the compilation.
+Users only need :any:`build` and :any:`build_config` to do the compilation,
+and :any:`save_param_dict` to save the parameters into bytes.
 The other APIs are for more advanced interaction with the compiler toolchain.
 """
 from __future__ import absolute_import
@@ -10,6 +11,7 @@
 from . import build_module
 from . build_module import build, optimize, build_config
 from . compile_engine import engine, graph_key
+from . param_dict import save_param_dict, load_param_dict
 
 from .. import symbol as _symbol
 from .. import graph as _graph

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 28 additions & 4 deletions

@@ -9,7 +9,7 @@
 from .. import graph as _graph
 
 OPT_PASS_LEVEL = {
-    "SimplifyInference": 2,
+    "SimplifyInference": 0,
     "PrecomputePrune": 2,
     "OpFusion": 1
 }
@@ -26,6 +26,7 @@ class BuildConfig(object):
     current = None
     defaults = {
         "opt_level": 2,
+        "add_pass": None,
     }
     def __init__(self, **kwargs):
         self._old_scope = None
@@ -53,6 +54,23 @@ def __exit__(self, ptype, value, trace):
         assert self._old_scope
         BuildConfig.current = self._old_scope
 
+    def pass_enabled(self, pass_name):
+        """Get whether a pass is enabled.
+
+        Parameters
+        ----------
+        pass_name : str
+            The optimization pass name
+
+        Returns
+        -------
+        enabled : bool
+            Whether the pass is enabled.
+        """
+        if self.add_pass and pass_name in self.add_pass:
+            return True
+        return self.opt_level >= OPT_PASS_LEVEL[pass_name]
+
 
 BuildConfig.current = BuildConfig()
 
@@ -64,6 +82,9 @@ def build_config(**kwargs):
     opt_level: int, default=2
         Optimization level. See OPT_PASS_LEVEL for level of each pass.
 
+    add_pass: set of str
+        Optimization passes to be added regardless of optimization level.
+
     Returns
     -------
     config: BuildConfig
@@ -120,7 +141,7 @@ def optimize(graph, shape, dtype="float32"):
     """
     # pylint: disable=unused-argument
     cfg = BuildConfig.current
-    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
+    if cfg.pass_enabled("SimplifyInference"):
         graph = graph_attr.set_shape_inputs(graph, shape)
         graph = graph.apply(["InferShape", "SimplifyInference"])
     return graph
@@ -182,14 +203,17 @@ def build(graph, target, shape, dtype="float32", params=None):
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
-    if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
+    if params and cfg.pass_enabled("PrecomputePrune"):
         graph, params = precompute_prune(graph, params)
         shape, dtype = _update_shape_dtype(shape, dtype, params)
     # Operator fusion and generation
     graph = graph_attr.set_shape_inputs(graph, shape)
     graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
-    graph._set_json_attr("opt_level", cfg.opt_level, "int")
+    if cfg.pass_enabled("OpFusion"):
+        graph._set_json_attr("opt_level", 1, "int")
+    else:
+        graph._set_json_attr("opt_level", 0, "int")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
     libmod = graph_attr._move_out_module(graph, "module")
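The new `pass_enabled` logic combines `opt_level` thresholds with explicit `add_pass` overrides: a pass runs if it is explicitly requested, or if the configured optimization level reaches the pass's registered level. A minimal self-contained sketch of that behavior (the `BuildConfig` here is a simplified stand-in for illustration, not the real nnvm class):

```python
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "PrecomputePrune": 2,
    "OpFusion": 1,
}

class BuildConfig:
    """Simplified stand-in for nnvm.compiler.BuildConfig (illustration only)."""
    def __init__(self, opt_level=2, add_pass=None):
        self.opt_level = opt_level
        self.add_pass = add_pass

    def pass_enabled(self, pass_name):
        # a pass runs if explicitly added, or if opt_level reaches its threshold
        if self.add_pass and pass_name in self.add_pass:
            return True
        return self.opt_level >= OPT_PASS_LEVEL[pass_name]

cfg = BuildConfig(opt_level=0, add_pass={"OpFusion"})
print(cfg.pass_enabled("OpFusion"))         # True: forced via add_pass
print(cfg.pass_enabled("PrecomputePrune"))  # False: level 0 < 2
```

This is why the commit lowers `"SimplifyInference"` to level 0: the pass now effectively always runs, while `add_pass` lets callers force higher-level passes without raising `opt_level` globally.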
nnvm/python/nnvm/compiler/param_dict.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+"""Helper utility to save parameter dict"""
+import tvm
+
+_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
+_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
+
+def save_param_dict(params):
+    """Save parameter dictionary to binary bytes.
+
+    The resulting binary bytes can be loaded by the
+    GraphModule with API "load_params".
+
+    Parameters
+    ----------
+    params : dict of str to NDArray
+        The parameter dictionary.
+
+    Returns
+    -------
+    param_bytes: bytearray
+        Serialized parameters.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        # compile and save the modules to file.
+        graph, lib, params = nnvm.compiler.build(
+            graph, target, shape={"data": data_shape}, params=params)
+        module = graph_runtime.create(graph, lib, tvm.gpu(0))
+        # save the parameters as a byte array
+        param_bytes = nnvm.compiler.save_param_dict(params)
+        # We can serialize the param_bytes and load it back later.
+        # Pass the byte array to the module to directly set parameters
+        module["load_params"](param_bytes)
+    """
+    args = []
+    for k, v in params.items():
+        args.append(k)
+        args.append(tvm.nd.array(v))
+    return _save_param_dict(*args)
+
+
+def load_param_dict(param_bytes):
+    """Load parameter dictionary from binary bytes.
+
+    Parameters
+    ----------
+    param_bytes: bytearray
+        Serialized parameters.
+
+    Returns
+    -------
+    params : dict of str to NDArray
+        The parameter dictionary.
+    """
+    if isinstance(param_bytes, (bytes, str)):
+        param_bytes = bytearray(param_bytes)
+    load_mod = _load_param_dict(param_bytes)
+    size = load_mod(0)
+    param_dict = {}
+    for i in range(size):
+        key = load_mod(1, i)
+        dltensor_handle = load_mod(2, i)
+        param_dict[key] = tvm.nd.NDArray(dltensor_handle, False)
+    return param_dict
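`save_param_dict`/`load_param_dict` round-trip a name-to-array dictionary through an opaque byte blob via TVM's registered packed functions. The sketch below illustrates the same round-trip contract in pure Python with a hypothetical length-prefixed `struct` layout; it is not the binary format TVM actually emits, and uses plain float lists in place of NDArrays:

```python
import struct

def save_param_dict(params):
    """Pack a {name: list-of-floats} dict into bytes (illustrative format only)."""
    out = bytearray()
    out += struct.pack("<Q", len(params))          # number of entries
    for name, values in params.items():
        key = name.encode("utf-8")
        out += struct.pack("<Q", len(key)) + key   # length-prefixed key
        out += struct.pack("<Q", len(values))      # number of values
        out += struct.pack("<%dd" % len(values), *values)
    return bytes(out)

def load_param_dict(data):
    """Inverse of save_param_dict above."""
    off = 0
    (n,) = struct.unpack_from("<Q", data, off); off += 8
    params = {}
    for _ in range(n):
        (klen,) = struct.unpack_from("<Q", data, off); off += 8
        key = data[off:off + klen].decode("utf-8"); off += klen
        (vlen,) = struct.unpack_from("<Q", data, off); off += 8
        values = list(struct.unpack_from("<%dd" % vlen, data, off)); off += 8 * vlen
        params[key] = values
    return params

blob = save_param_dict({"w": [1.0, 2.0], "b": [0.5]})
print(load_param_dict(blob))  # {'w': [1.0, 2.0], 'b': [0.5]}
```

The key property mirrored from the real API is that the blob is self-describing: the loader recovers both the keys and the array payloads without any side-channel metadata, which is what lets `module["load_params"](param_bytes)` work on a freshly created GraphModule.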

nnvm/python/nnvm/graph.py

Lines changed: 30 additions & 2 deletions

@@ -12,7 +12,7 @@
 from ._base import c_array, c_str, nn_uint, py_str, string_types
 from ._base import GraphHandle, SymbolHandle
 from ._base import check_call
-from .symbol import Symbol, Group as _Group
+from .symbol import Variable, Symbol, Group as _Group
 
 class GraphIndex(object):
     """Index for quickly accessing graph attributes.
@@ -174,9 +174,19 @@ def symbol(self):
         check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle)))
         return Symbol(shandle)
 
+    def json(self):
+        """Get JSON representation of the graph
+
+        Returns
+        -------
+        json : str
+            JSON representation of the graph
+        """
+        return self.apply("SaveJSON").json_attr("json")
+
     def _tvm_graph_json(self):
         """Get TVM graph json"""
-        return self.apply("SaveJSON").json_attr("json")
+        return self.json()
 
     @property
     def index(self):
@@ -225,6 +235,24 @@ def apply(self, passes):
         return Graph(ghandle)
 
 
+def load_json(json_str):
+    """Create a new graph by loading from json
+
+    Parameters
+    ----------
+    json_str : str
+        The json string
+
+    Returns
+    -------
+    graph : Graph
+        The loaded graph
+    """
+    ret = create(Variable("x"))
+    ret._set_json_attr("json", json_str)
+    return ret.apply("LoadJSON")
+
+
 def create(symbol):
     """Create a new graph from symbol.

nnvm/src/compiler/graph_fuse.cc

Lines changed: 1 addition & 37 deletions

@@ -15,46 +15,10 @@
 #include <tvm/lowered_func.h>
 #include <dmlc/parameter.h>
 #include "./compile_engine.h"
-#include "../../tvm/src/runtime/graph/graph_runtime.h"
+#include "./graph_runtime.h"
 
 namespace nnvm {
 namespace compiler {
-
-
-struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
-  std::string func_name;
-  uint32_t num_inputs;
-  uint32_t num_outputs;
-  uint32_t flatten_data;
-
-  DMLC_DECLARE_PARAMETER(TVMOpParam) {
-    DMLC_DECLARE_FIELD(func_name);
-    DMLC_DECLARE_FIELD(num_inputs).set_default(1);
-    DMLC_DECLARE_FIELD(num_outputs).set_default(1);
-    DMLC_DECLARE_FIELD(flatten_data).set_default(0);
-  }
-};
-
-DMLC_REGISTER_PARAMETER(TVMOpParam);
-
-// parser
-inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
-  TVMOpParam param;
-  param.Init(attrs->dict);
-  attrs->parsed = std::move(param);
-}
-
-NNVM_REGISTER_OP(tvm_op)
-.set_attr_parser(TVMOpParamParser)
-.set_num_inputs([](const NodeAttrs& attrs) {
-    const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
-    return param.num_inputs;
-  })
-.set_num_outputs([](const NodeAttrs& attrs) {
-    const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
-    return param.num_outputs;
-  });
-
 using namespace tvm;
 
 // The single fuse rule.
