Skip to content

Commit 5629330

Browse files
committed
Updates (#14)
* Remove outstanding cython functions * Add in operator overload * Enable JSON to save version
1 parent 20ac351 commit 5629330

File tree

5 files changed

+101
-35
lines changed

5 files changed

+101
-35
lines changed

nnvm/python/nnvm/cython/symbol.pyx

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@ include "./base.pyi"
1414

1515
cdef extern from "nnvm/c_api.h":
1616
const char* NNGetLastError();
17-
int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
18-
int NNSymbolCreateGroup(nn_uint num_symbols,
19-
SymbolHandle *symbols,
20-
SymbolHandle *out);
2117
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
2218
AtomicSymbolCreator **out_array);
2319
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
@@ -34,31 +30,10 @@ cdef extern from "nnvm/c_api.h":
3430
const char ***arg_descriptions,
3531
const char **return_type);
3632
int NNSymbolFree(SymbolHandle symbol);
37-
int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
38-
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
39-
int NNSymbolGetAttr(SymbolHandle symbol,
40-
const char* key,
41-
const char** out,
42-
int *success);
4333
int NNSymbolSetAttrs(SymbolHandle symbol,
4434
nn_uint num_param,
4535
const char** keys,
4636
const char** values);
47-
int NNSymbolListAttrs(SymbolHandle symbol,
48-
int recursive_option,
49-
nn_uint *out_size,
50-
const char*** out);
51-
int NNSymbolListArguments(SymbolHandle symbol,
52-
nn_uint *out_size,
53-
const char ***out_str_array);
54-
int NNSymbolListOutputs(SymbolHandle symbol,
55-
nn_uint *out_size,
56-
const char ***out_str_array);
57-
int NNSymbolGetInternals(SymbolHandle symbol,
58-
SymbolHandle *out);
59-
int NNSymbolGetOutput(SymbolHandle symbol,
60-
nn_uint index,
61-
SymbolHandle *out);
6237
int NNSymbolCompose(SymbolHandle sym,
6338
const char* name,
6439
nn_uint num_args,

nnvm/python/nnvm/symbol.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,71 @@ class Symbol(SymbolBase):
3030

3131
def __add__(self, other):
3232
if isinstance(other, Symbol):
33-
return _internal.__add__symbol__(self, other)
33+
return _internal.__add_symbol__(self, other)
3434
elif isinstance(other, _Number):
35-
return _internal.__add__scalar__(self, scalar=other)
35+
return _internal.__add_scalar__(self, scalar=other)
3636
else:
3737
raise TypeError("type %s not supported" % str(type(other)))
3838

39+
def __radd__(self, other):
40+
return self.__add__(other)
41+
42+
def __sub__(self, other):
43+
if isinstance(other, Symbol):
44+
return _internal.__sub_symbol__(self, other)
45+
if isinstance(other, Number):
46+
return _internal.__sub_scalar__(self, scalar=other)
47+
else:
48+
raise TypeError('type %s not supported' % str(type(other)))
49+
50+
def __rsub__(self, other):
51+
if isinstance(other, Number):
52+
return _internal.__rsub_scalar__(self, scalar=other)
53+
else:
54+
raise TypeError('type %s not supported' % str(type(other)))
55+
56+
def __mul__(self, other):
57+
if isinstance(other, Symbol):
58+
return _internal.__mul_symbol__(self, other)
59+
if isinstance(other, Number):
60+
return _internal.__mul_scalar__(self, scalar=other)
61+
else:
62+
raise TypeError('type %s not supported' % str(type(other)))
63+
64+
def __rmul__(self, other):
65+
return self.__mul__(other)
66+
67+
def __div__(self, other):
68+
if isinstance(other, Symbol):
69+
return _internal.__div_symbol__(self, other)
70+
if isinstance(other, Number):
71+
return _internal.__div_scalar__(self, scalar=other)
72+
else:
73+
raise TypeError('type %s not supported' % str(type(other)))
74+
75+
def __rdiv__(self, other):
76+
if isinstance(other, Number):
77+
return _internal.__rdiv_scalar__(self, scalar=other)
78+
else:
79+
raise TypeError('type %s not supported' % str(type(other)))
80+
81+
def __truediv__(self, other):
82+
return self.__div__(other)
83+
84+
def __rtruediv__(self, other):
85+
return self.__rdiv__(other)
86+
87+
def __pow__(self, other):
88+
if isinstance(other, Symbol):
89+
return _internal.__pow_symbol__(self, other)
90+
if isinstance(other, Number):
91+
return _internal.__pow_scalar__(self, scalar=other)
92+
else:
93+
raise TypeError('type %s not supported' % str(type(other)))
94+
95+
def __neg__(self):
96+
return self.__mul__(-1.0)
97+
3998
def __copy__(self):
4099
return self.__deepcopy__()
41100

nnvm/src/example/operator.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ using nnvm::NodeAttrs;
1111

1212
NNVM_REGISTER_OP(add)
1313
.describe("add two data together")
14-
.set_num_inputs(2)
15-
.attr("inplace_pair", std::make_pair(0, 0));
14+
.set_num_inputs(2);
1615

16+
NNVM_REGISTER_OP(__add_symbol__)
17+
.describe("Alias of add")
18+
.set_num_inputs(2);
1719

1820
NNVM_REGISTER_OP(exp)
1921
.describe("take exponmential")

nnvm/src/pass/saveload_json.cc

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,32 @@ namespace pass {
3030
// auxiliary node structure for serialization.
3131
struct JSONNode {
3232
// the node entry structure in serialized format
33-
typedef std::pair<uint32_t, uint32_t> Entry;
33+
struct Entry {
34+
uint32_t node_id;
35+
uint32_t index;
36+
uint32_t version;
37+
void Save(dmlc::JSONWriter *writer) const {
38+
writer->BeginArray();
39+
writer->WriteArrayItem(node_id);
40+
writer->WriteArrayItem(index);
41+
writer->WriteArrayItem(version);
42+
writer->EndArray();
43+
}
44+
void Load(dmlc::JSONReader *reader) {
45+
reader->BeginArray();
46+
CHECK(reader->NextArrayItem()) << "invalid json format";
47+
reader->Read(&node_id);
48+
CHECK(reader->NextArrayItem()) << "invalid json format";
49+
reader->Read(&index);
50+
if (reader->NextArrayItem()) {
51+
reader->Read(&version);
52+
CHECK(!reader->NextArrayItem()) << "invalid json format";
53+
} else {
54+
version = 0;
55+
}
56+
}
57+
};
58+
3459
// pointer to the graph node
3560
NodePtr node;
3661
// inputs
@@ -75,6 +100,10 @@ struct JSONNode {
75100
if (op_type_str != "null") {
76101
try {
77102
node->op = Op::Get(op_type_str);
103+
// rebuild attribute parser
104+
if (node->op->attr_parser != nullptr) {
105+
node->op->attr_parser(&(node->attrs));
106+
}
78107
} catch (const dmlc::Error &err) {
79108
std::ostringstream os;
80109
os << "Failed loading Op " << node->attrs.name
@@ -132,7 +161,7 @@ Graph LoadJSON(const Graph& src) {
132161
n.node->inputs.reserve(n.inputs.size());
133162
for (const JSONNode::Entry &e : n.inputs) {
134163
n.node->inputs.emplace_back(
135-
NodeEntry{jgraph.nodes[e.first].node, e.second});
164+
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
136165
}
137166
n.node->control_deps.reserve(n.control_deps.size());
138167
for (uint32_t nid : n.control_deps) {
@@ -150,7 +179,7 @@ Graph LoadJSON(const Graph& src) {
150179
ret.outputs.reserve(jgraph.heads.size());
151180
for (const JSONNode::Entry &e : jgraph.heads) {
152181
ret.outputs.emplace_back(
153-
NodeEntry{jgraph.nodes[e.first].node, e.second});
182+
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
154183
}
155184
return ret;
156185
}
@@ -170,7 +199,7 @@ Graph SaveJSON(const Graph& src) {
170199
jnode.inputs.reserve(n->inputs.size());
171200
for (const NodeEntry& e : n->inputs) {
172201
jnode.inputs.emplace_back(
173-
std::make_pair(node2index.at(e.node.get()), e.index));
202+
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
174203
}
175204
for (const NodePtr& c : n->control_deps) {
176205
jnode.control_deps.push_back(node2index.at(c.get()));
@@ -179,7 +208,8 @@ Graph SaveJSON(const Graph& src) {
179208
});
180209

181210
for (const NodeEntry& e : src.outputs) {
182-
jgraph.heads.push_back(std::make_pair(node2index.at(e.node.get()), e.index));
211+
jgraph.heads.push_back(
212+
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
183213
}
184214

185215
std::ostringstream os;

nnvm/tests/python/test_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_order_mutation_pass():
3333
assert nindex['assign'] in jnodes[nindex['add2']]['control_deps']
3434
assert nindex['conv'] in jnodes[nindex['assign']]['control_deps']
3535
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
36-
36+
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
3737

3838
if __name__ == "__main__":
3939
test_order_mutation_pass()

0 commit comments

Comments
 (0)