Skip to content

Commit 647267d

Browse files
committed
Allow libaray path to be configurable (apache#50)
* Allow libaray path to be configurable * Enable partial shape inference result to be passed via shape * fix python3 * disallow copy assign in index
1 parent bf8bd96 commit 647267d

File tree

5 files changed

+49
-9
lines changed

5 files changed

+49
-9
lines changed

nnvm/include/nnvm/graph.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ class IndexedGraph {
171171
inline const std::vector<NodeEntry>& outputs() const {
172172
return outputs_;
173173
}
174+
// disalllow copy assign
175+
IndexedGraph(const IndexedGraph&) = delete;
174176

175177
private:
176178
friend class Graph;

nnvm/python/nnvm/libinfo.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
# coding: utf-8
22
"""Information about nnvm."""
33
from __future__ import absolute_import
4+
import sys
45
import os
56
import platform
67

8+
if sys.version_info[0] == 3:
9+
import builtins as __builtin__
10+
else:
11+
import __builtin__
12+
713
def find_lib_path():
814
"""Find NNNet dynamic library files.
915
@@ -12,10 +18,19 @@ def find_lib_path():
1218
lib_path : list(string)
1319
List of all found path to the libraries
1420
"""
15-
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
16-
api_path = os.path.join(curr_path, '../../lib/')
17-
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
18-
dll_path = [curr_path, api_path, cmake_build_path]
21+
if hasattr(__builtin__, "NNVM_BASE_PATH"):
22+
base_path = __builtin__.NNVM_BASE_PATH
23+
else:
24+
base_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
25+
26+
if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
27+
lib_name = __builtin__.NNVM_LIBRARY_NAME
28+
else:
29+
lib_name = "libnnvm_example"
30+
31+
api_path = os.path.join(base_path, '../../lib/')
32+
cmake_build_path = os.path.join(base_path, '../../build/Release/')
33+
dll_path = [base_path, api_path, cmake_build_path]
1934
if os.name == 'nt':
2035
vs_configuration = 'Release'
2136
if platform.architecture()[0] == '64bit':
@@ -27,9 +42,9 @@ def find_lib_path():
2742
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
2843
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
2944
if os.name == 'nt':
30-
dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path]
45+
dll_path = [os.path.join(p, '%s.dll' % lib_name) for p in dll_path]
3146
else:
32-
dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path]
47+
dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path]
3348
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
3449
if len(lib_path) == 0:
3550
raise RuntimeError('Cannot find the files.\n' +

nnvm/python/nnvm/symbol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __rsub__(self, other):
5757
def __mul__(self, other):
5858
if isinstance(other, Symbol):
5959
return _internal.__mul_symbol__(self, other)
60-
if isinstance(other, Number):
60+
if isinstance(other, _Number):
6161
return _internal.__mul_scalar__(self, scalar=other)
6262
else:
6363
raise TypeError('type %s not supported' % str(type(other)))

nnvm/src/pass/infer_shape_type.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,16 @@ Graph InferAttr(Graph &&ret,
2323
using AttrVector = std::vector<AttrType>;
2424
const IndexedGraph& idx = ret.indexed_graph();
2525
static auto& finfer_shape =
26-
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
26+
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
2727
static auto& backward_map =
2828
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
2929
// reshape shape vector
30-
AttrVector rshape(idx.num_node_entries(), default_val);
30+
AttrVector rshape;
31+
if (ret.attrs.count(attr_name) != 0) {
32+
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
33+
} else {
34+
rshape.resize(idx.num_node_entries(), default_val);
35+
}
3136

3237
if (ret.attrs.count(input_name) != 0) {
3338
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
@@ -39,6 +44,7 @@ Graph InferAttr(Graph &&ret,
3944
// erase the provided arguments
4045
ret.attrs.erase(input_name);
4146
}
47+
4248
std::string shape_attr_key;
4349
if (ret.attrs.count(attr_key_name) != 0) {
4450
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);

nnvm/tests/python/test_graph.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,22 @@ def test_infer_shape():
5959
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
6060
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
6161

62+
def test_infer_shape_known_partial():
63+
x = sym.Variable('x', shape=(4, 2))
64+
y = sym.add(x, x, name='add1')
65+
y = sym.reshape(y, target=(2, 4), name="reshape1")
66+
g = graph.create(y)
67+
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
68+
shape = [[4, 2], [] , []]
69+
g._set_json_attr("shape", shape, 'list_shape')
70+
g = g.apply("InferShape")
71+
jnodes = jgraph['nodes']
72+
jnode_row_ptr = jgraph['node_row_ptr']
73+
nindex = {n['name']: i for i, n in enumerate(jnodes)}
74+
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
75+
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
76+
77+
6278
def test_infer_type():
6379
x = sym.Variable('x')
6480
y = sym.add(x, x, name='add1')
@@ -116,6 +132,7 @@ def test_plan_memory():
116132
test_graph_json_attr()
117133
test_json_pass()
118134
test_infer_shape()
135+
test_infer_shape_known_partial()
119136
test_infer_type()
120137
test_place_device()
121138
test_plan_memory()

0 commit comments

Comments
 (0)