@@ -13,6 +13,43 @@ namespace symbol_constants {
13
13
const char *kNamespaceSeparator = " _" ;
14
14
} // namespace symbol_constants
15
15
16
+ // auxililary version attribute in variable.
17
+ struct VariableParam {
18
+ uint32_t version{0 };
19
+ };
20
+
21
+ std::shared_ptr<Node> CreateVariableNode (const std::string& name) {
22
+ std::shared_ptr<Node> n = Node::Create ();
23
+ n->op = nullptr ;
24
+ n->attrs .name = name;
25
+ n->attrs .parsed = VariableParam ();
26
+ return n;
27
+ }
28
+
29
+ // scan over a node's input, update the version to latest
30
+ // If the node's op mutates a certain input variable,
31
+ // The version of that varaible will increase
32
+ // version is used to implicitly order the mutation sequences
33
+ inline void UpdateNodeVersion (Node *n) {
34
+ static auto & fmutate_inputs = Op::GetAttr<FMutateInput>(" FMutateInput" );
35
+ for (NodeEntry& e : n->inputs ) {
36
+ if (e.node ->is_variable ()) {
37
+ e.version = nnvm::get<VariableParam>(e.node ->attrs .parsed ).version ;
38
+ }
39
+ }
40
+ if (fmutate_inputs.count (n->op ) != 0 ) {
41
+ FMutateInput fmutate = fmutate_inputs[n->op ];
42
+ for (uint32_t i = 0 ; i < n->inputs .size (); ++i) {
43
+ if (fmutate (n->attrs , i)) {
44
+ NodeEntry& e = n->inputs [i];
45
+ CHECK (e.node ->is_variable ())
46
+ << " Mutation target can only be Variable" ;
47
+ // increase the version of the variable.
48
+ ++nnvm::get<VariableParam>(e.node ->attrs .parsed ).version ;
49
+ }
50
+ }
51
+ }
52
+ }
16
53
17
54
inline std::string DefaultVarName (const std::string &op_name,
18
55
const std::string &arg_name) {
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
67
104
for (const auto &kv : old_new) {
68
105
for (const NodeEntry& e : kv.first ->inputs ) {
69
106
Node *ptr = e.node .get ();
70
- kv.second ->inputs .emplace_back (NodeEntry{old_new[ptr], e.index });
107
+ kv.second ->inputs .emplace_back (NodeEntry{old_new[ptr], e.index , e. version });
71
108
}
72
109
}
73
110
// set the head
74
111
Symbol ret;
75
112
for (const NodeEntry &e : outputs) {
76
- ret.outputs .emplace_back (NodeEntry{old_new[e.node .get ()], e.index });
113
+ ret.outputs .emplace_back (NodeEntry{old_new[e.node .get ()], e.index , e. version });
77
114
}
78
115
return ret;
79
116
}
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
95
132
os << " Name: " << node->attrs .name << " Op:" << node->op ->name << ' \n '
96
133
<< " Inputs:\n " ;
97
134
for (size_t i = 0 ; i < node->inputs .size (); ++i) {
98
- os << " \t arg[" << i << " ]=" << node->inputs [i].node ->attrs .name
99
- << ' (' << node->inputs [i].index << " )\n " ;
135
+ const NodeEntry& e = node->inputs [i];
136
+ os << " \t arg[" << i << " ]=" << e.node ->attrs .name
137
+ << ' (' << e.index << " )" ;
138
+ if (e.node ->is_variable ()) {
139
+ os << " version=" << e.version << ' \n ' ;
140
+ } else {
141
+ os << ' \n ' ;
142
+ }
100
143
}
101
144
os << " Attrs:\n " ;
102
145
for (auto &kv : node->attrs .dict ) {
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
163
206
void Symbol::Compose (const std::vector<Symbol>& args,
164
207
const std::unordered_map<std::string, Symbol>& kwargs,
165
208
const std::string& name) {
209
+ static auto & flist_inputs = Op::GetAttr<FListInputNames>(" FListInputNames" );
210
+
166
211
CHECK_EQ (outputs.size (), 1 )
167
212
<< " Only composition of value function is supported currently" ;
168
213
CHECK (!outputs[0 ].node ->is_variable ()) << " Variable cannot be composed" ;
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
193
238
}
194
239
// switch to keyword argument matching
195
240
if (args.size () != n_req) {
196
- static auto & flist_inputs = Op::GetAttr<FListInputNames>(" FListInputNames" );
197
241
FListInputNames fn = flist_inputs.get (n->op , nullptr );
198
242
auto arg_names = (fn == nullptr ) ? std::vector<std::string>{" data" } : fn (n->attrs );
199
243
if (arg_names.size () != n_req) {
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
206
250
n->inputs [i] = it->second .outputs [0 ];
207
251
++nmatched;
208
252
} else {
209
- n->inputs [i] = NodeEntry{Node::Create (), 0 };
210
- n-> inputs [i]. node -> attrs . name = DefaultVarName (name, arg_names[i]);
253
+ n->inputs [i] = NodeEntry{
254
+ CreateVariableNode ( DefaultVarName (name, arg_names[i])), 0 , 0 } ;
211
255
}
212
256
}
213
257
@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
226
270
n->inputs .push_back (s.outputs [0 ]);
227
271
}
228
272
}
273
+ UpdateNodeVersion (n);
229
274
} else {
230
275
// general composition
231
276
CHECK_EQ (args.size (), 0 )
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
253
298
DFSVisit (this ->outputs , find_replace_map);
254
299
255
300
if (nmatched == kwargs.size () && arg_counter < args.size ()) {
301
+ std::vector<Node*> update_nodes;
256
302
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
257
- auto find_replace_plan = [&replace_map, &replace_plan]
303
+ auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes ]
258
304
(const std::shared_ptr<Node> &node) {
259
305
// visit all the childs, find possible replacement
306
+ bool repl = false ;
260
307
for (size_t i = 0 ; i < node->inputs .size (); ++i) {
261
308
NodeEntry *e = &(node->inputs [i]);
262
309
if (e->node ->is_variable ()) {
263
310
auto iter = replace_map.find (e->node .get ());
264
311
if (iter != replace_map.end ()) {
265
312
replace_plan.push_back (std::make_pair (e, iter->second ));
313
+ repl = true ;
266
314
}
267
315
}
268
316
}
317
+ if (repl) update_nodes.push_back (node.get ());
269
318
};
270
319
DFSVisit (this ->outputs , find_replace_plan);
271
320
272
321
for (const auto & kv : replace_plan) {
273
322
*(kv.first ) = *(kv.second );
274
323
}
324
+ for (Node* n : update_nodes) {
325
+ UpdateNodeVersion (n);
326
+ }
275
327
} else {
276
328
std::vector<std::string> keys = GetKeys (kwargs);
277
329
std::vector<std::string> arg_names = ListArguments ();
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
303
355
Symbol ret;
304
356
DFSVisit (this ->outputs , [&ret](const std::shared_ptr<Node>& node) {
305
357
Node* n = node.get ();
306
- uint32_t nout = n->num_outputs ();
307
- for (uint32_t i = 0 ; i < nout; ++i) {
308
- ret.outputs .emplace_back (NodeEntry{node, i});
358
+ if (n->is_variable ()) {
359
+ // grab version from variable.
360
+ VariableParam& param = nnvm::get<VariableParam>(n->attrs .parsed );
361
+ ret.outputs .emplace_back (NodeEntry{node, 0 , param.version });
362
+ } else {
363
+ uint32_t nout = n->num_outputs ();
364
+ for (uint32_t i = 0 ; i < nout; ++i) {
365
+ ret.outputs .emplace_back (NodeEntry{node, i, 0 });
366
+ }
309
367
}
310
368
});
311
369
return ret;
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
325
383
}
326
384
}
327
385
if (node->op != nullptr && node->op ->attr_parser != nullptr ) {
328
- (* node->op ->attr_parser ) (&(node->attrs ));
386
+ node->op ->attr_parser (&(node->attrs ));
329
387
}
330
388
}
331
389
@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
366
424
n->op = op;
367
425
n->attrs .dict = std::move (attrs);
368
426
if (n->op ->attr_parser != nullptr ) {
369
- (* n->op ->attr_parser ) (&(n->attrs ));
427
+ n->op ->attr_parser (&(n->attrs ));
370
428
}
371
- s.outputs .emplace_back (NodeEntry{std::move (n), 0 });
429
+ s.outputs .emplace_back (NodeEntry{std::move (n), 0 , 0 });
372
430
return s;
373
431
}
374
432
@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
382
440
383
441
Symbol Symbol::CreateVariable (const std::string& name) {
384
442
Symbol s;
385
- std::shared_ptr<Node> n = Node::Create ();
386
- n->op = nullptr ;
387
- n->attrs .name = name;
388
- s.outputs .emplace_back (NodeEntry{std::move (n), 0 });
443
+ s.outputs .emplace_back (NodeEntry{CreateVariableNode (name), 0 , 0 });
389
444
return s;
390
445
}
391
446
0 commit comments