Skip to content

Commit d5e1cea

Browse files
committed
Change op function pointer to std::function, enable mutation (apache#6)
1 parent 88520e1 commit d5e1cea

File tree

6 files changed

+123
-30
lines changed

6 files changed

+123
-30
lines changed

nnvm/include/nnvm/node.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ struct NodeEntry {
2424
std::shared_ptr<Node> node;
2525
/*! \brief index of output from the source. */
2626
uint32_t index;
27+
/*!
28+
* \brief version of input Variable.
29+
* This field can only be nonzero when this->node is a Variable node.
30+
* version is increased by one each time a Variable get composed to a mutation Op.
31+
* This information can be helpful to decide order of operations when sequence of mutation happens.
32+
*/
33+
uint32_t version;
2734
};
2835

2936
/*!

nnvm/include/nnvm/op.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,13 @@ class Op {
101101
* \param attrs The attribute of the node
102102
* \return number of outputs.
103103
*/
104-
uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr;
104+
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
105105
/*!
106106
* \brief get number of inputs given information about the node.
107107
* \param attrs The attribute of the node
108108
* \return number of inputs
109109
*/
110-
uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr;
110+
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
111111
/*!
112112
* \brief Attribute parser to parse the NodeAttrs information.
113113
*
@@ -140,8 +140,7 @@ class Op {
140140
* }
141141
* \endcode
142142
*/
143-
void (*attr_parser)(NodeAttrs* attrs) = nullptr;
144-
143+
std::function<void(NodeAttrs* attrs)> attr_parser = nullptr;
145144
// function fields.
146145
/*!
147146
* \brief setter function during registration
@@ -161,7 +160,7 @@ class Op {
161160
* \param fn The function to be set.
162161
* \return reference to self.
163162
*/
164-
inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
163+
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
165164
/*!
166165
* \brief Set the num_outputs
167166
* \param n The number of outputs to be set.
@@ -173,13 +172,13 @@ class Op {
173172
* \param fn The function to be set.
174173
* \return reference to self.
175174
*/
176-
inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
175+
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
177176
/*!
178177
* \brief Set the attr_parser function.
179178
* \param fn The number of outputs to be set.
180179
* \return reference to self.
181180
*/
182-
inline Op& set_attr_parser(void (*fn)(NodeAttrs* attrs)); // NOLINT(*)
181+
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
183182
/*!
184183
* \brief Register additional attributes to operator.
185184
* \param attr_name The name of the attribute.
@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
342341
return *this;
343342
}
344343

345-
inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*)
344+
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
346345
this->get_num_inputs = fn;
347346
return *this;
348347
}
@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
352351
return *this;
353352
}
354353

355-
inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)) { // NOLINT(*)
354+
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
356355
this->get_num_outputs = fn;
357356
return *this;
358357
}
359358

360-
inline Op& Op::set_attr_parser(void (*fn)(NodeAttrs* attrs)) { // NOLINT(*)
359+
inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { // NOLINT(*)
361360
this->attr_parser = fn;
362361
return *this;
363362
}

nnvm/include/nnvm/op_attr_types.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
namespace nnvm {
1414

15-
// These types are optional attributes in each op
16-
// Some of them are needed for certain pass.
15+
// These types are optional attributes in each operator.
16+
// Each attribute can be required by some passes.
1717

1818
/*!
1919
* \brief Return list of input arguments names of each operator.
@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
3737
*/
3838
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
3939

40+
/*!
41+
* \brief Check whether operator will mutate k-th input.
42+
* \param index The input index
43+
* \return Whether this operator will mutate index-th input.
44+
*
45+
* \note Register under "FMutateInput", default return false
46+
* FMutateInputs enables mutation order handling correctly.
47+
*/
48+
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;
49+
4050
} // namespace nnvm
4151

4252
#endif // NNVM_OP_ATTR_TYPES_H_

nnvm/src/core/symbolic.cc

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,43 @@ namespace symbol_constants {
1313
const char *kNamespaceSeparator = "_";
1414
} // namespace symbol_constants
1515

16+
// auxililary version attribute in variable.
17+
struct VariableParam {
18+
uint32_t version{0};
19+
};
20+
21+
std::shared_ptr<Node> CreateVariableNode(const std::string& name) {
22+
std::shared_ptr<Node> n = Node::Create();
23+
n->op = nullptr;
24+
n->attrs.name = name;
25+
n->attrs.parsed = VariableParam();
26+
return n;
27+
}
28+
29+
// scan over a node's input, update the version to latest
30+
// If the node's op mutates a certain input variable,
31+
// The version of that varaible will increase
32+
// version is used to implicitly order the mutation sequences
33+
inline void UpdateNodeVersion(Node *n) {
34+
static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
35+
for (NodeEntry& e : n->inputs) {
36+
if (e.node->is_variable()) {
37+
e.version = nnvm::get<VariableParam>(e.node->attrs.parsed).version;
38+
}
39+
}
40+
if (fmutate_inputs.count(n->op) != 0) {
41+
FMutateInput fmutate = fmutate_inputs[n->op];
42+
for (uint32_t i = 0; i < n->inputs.size(); ++i) {
43+
if (fmutate(n->attrs, i)) {
44+
NodeEntry& e = n->inputs[i];
45+
CHECK(e.node->is_variable())
46+
<< "Mutation target can only be Variable";
47+
// increase the version of the variable.
48+
++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
49+
}
50+
}
51+
}
52+
}
1653

1754
inline std::string DefaultVarName(const std::string &op_name,
1855
const std::string &arg_name) {
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
67104
for (const auto &kv : old_new) {
68105
for (const NodeEntry& e : kv.first->inputs) {
69106
Node *ptr = e.node.get();
70-
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index});
107+
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
71108
}
72109
}
73110
// set the head
74111
Symbol ret;
75112
for (const NodeEntry &e : outputs) {
76-
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index});
113+
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index, e.version});
77114
}
78115
return ret;
79116
}
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
95132
os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
96133
<< "Inputs:\n";
97134
for (size_t i = 0; i < node->inputs.size(); ++i) {
98-
os << "\targ[" << i << "]=" << node->inputs[i].node->attrs.name
99-
<< '(' << node->inputs[i].index << ")\n";
135+
const NodeEntry& e = node->inputs[i];
136+
os << "\targ[" << i << "]=" << e.node->attrs.name
137+
<< '(' << e.index << ")";
138+
if (e.node->is_variable()) {
139+
os << " version=" << e.version << '\n';
140+
} else {
141+
os << '\n';
142+
}
100143
}
101144
os << "Attrs:\n";
102145
for (auto &kv : node->attrs.dict) {
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
163206
void Symbol::Compose(const std::vector<Symbol>& args,
164207
const std::unordered_map<std::string, Symbol>& kwargs,
165208
const std::string& name) {
209+
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
210+
166211
CHECK_EQ(outputs.size(), 1)
167212
<< "Only composition of value function is supported currently";
168213
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
193238
}
194239
// switch to keyword argument matching
195240
if (args.size() != n_req) {
196-
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
197241
FListInputNames fn = flist_inputs.get(n->op, nullptr);
198242
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
199243
if (arg_names.size() != n_req) {
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
206250
n->inputs[i] = it->second.outputs[0];
207251
++nmatched;
208252
} else {
209-
n->inputs[i] = NodeEntry{Node::Create(), 0};
210-
n->inputs[i].node->attrs.name = DefaultVarName(name, arg_names[i]);
253+
n->inputs[i] = NodeEntry{
254+
CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
211255
}
212256
}
213257

@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
226270
n->inputs.push_back(s.outputs[0]);
227271
}
228272
}
273+
UpdateNodeVersion(n);
229274
} else {
230275
// general composition
231276
CHECK_EQ(args.size(), 0)
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
253298
DFSVisit(this->outputs, find_replace_map);
254299

255300
if (nmatched == kwargs.size() && arg_counter < args.size()) {
301+
std::vector<Node*> update_nodes;
256302
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
257-
auto find_replace_plan = [&replace_map, &replace_plan]
303+
auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
258304
(const std::shared_ptr<Node> &node) {
259305
// visit all the childs, find possible replacement
306+
bool repl = false;
260307
for (size_t i = 0; i < node->inputs.size(); ++i) {
261308
NodeEntry *e = &(node->inputs[i]);
262309
if (e->node->is_variable()) {
263310
auto iter = replace_map.find(e->node.get());
264311
if (iter != replace_map.end()) {
265312
replace_plan.push_back(std::make_pair(e, iter->second));
313+
repl = true;
266314
}
267315
}
268316
}
317+
if (repl) update_nodes.push_back(node.get());
269318
};
270319
DFSVisit(this->outputs, find_replace_plan);
271320

272321
for (const auto& kv : replace_plan) {
273322
*(kv.first) = *(kv.second);
274323
}
324+
for (Node* n : update_nodes) {
325+
UpdateNodeVersion(n);
326+
}
275327
} else {
276328
std::vector<std::string> keys = GetKeys(kwargs);
277329
std::vector<std::string> arg_names = ListArguments();
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
303355
Symbol ret;
304356
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) {
305357
Node* n = node.get();
306-
uint32_t nout = n->num_outputs();
307-
for (uint32_t i = 0; i < nout; ++i) {
308-
ret.outputs.emplace_back(NodeEntry{node, i});
358+
if (n->is_variable()) {
359+
// grab version from variable.
360+
VariableParam& param = nnvm::get<VariableParam>(n->attrs.parsed);
361+
ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
362+
} else {
363+
uint32_t nout = n->num_outputs();
364+
for (uint32_t i = 0; i < nout; ++i) {
365+
ret.outputs.emplace_back(NodeEntry{node, i, 0});
366+
}
309367
}
310368
});
311369
return ret;
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
325383
}
326384
}
327385
if (node->op != nullptr && node->op->attr_parser != nullptr) {
328-
(*node->op->attr_parser)(&(node->attrs));
386+
node->op->attr_parser(&(node->attrs));
329387
}
330388
}
331389

@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
366424
n->op = op;
367425
n->attrs.dict = std::move(attrs);
368426
if (n->op->attr_parser != nullptr) {
369-
(*n->op->attr_parser)(&(n->attrs));
427+
n->op->attr_parser(&(n->attrs));
370428
}
371-
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
429+
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
372430
return s;
373431
}
374432

@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
382440

383441
Symbol Symbol::CreateVariable(const std::string& name) {
384442
Symbol s;
385-
std::shared_ptr<Node> n = Node::Create();
386-
n->op = nullptr;
387-
n->attrs.name = name;
388-
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
443+
s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0});
389444
return s;
390445
}
391446

nnvm/src/example/operator.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <utility>
77

88
using nnvm::FListInputNames;
9+
using nnvm::FMutateInput;
910
using nnvm::NodeAttrs;
1011

1112
NNVM_REGISTER_OP(add)
@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)
2930

3031
NNVM_REGISTER_OP(add)
3132
.attr<std::string>("nick_name", "plus");
33+
34+
NNVM_REGISTER_OP(assign)
35+
.set_num_inputs(2)
36+
.set_num_outputs(1)
37+
.attr<FMutateInput>("FMutateInput", [](const NodeAttrs& attrs, uint32_t index) {
38+
return index == 0;
39+
});

nnvm/tests/python/test_symbol.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ def test_default_input():
2424
except NNVMError:
2525
pass
2626

27+
def test_mutate_input():
28+
x = sym.Variable('x')
29+
y = sym.conv2d(data=x, name='conv')
30+
z = sym.assign(x, y)
31+
t = sym.add(z, x)
32+
33+
try:
34+
z = sym.assign(z, z)
35+
assert False
36+
except NNVMError:
37+
pass
38+
39+
2740
if __name__ == "__main__":
2841
test_default_input()
2942
test_compose()
43+
test_mutate_input()

0 commit comments

Comments
 (0)