@@ -45,7 +45,7 @@ inline void UpdateNodeVersion(Node *n) {
4545 CHECK (e.node ->is_variable ())
4646 << " Mutation target can only be Variable" ;
4747 // increase the version of the variable.
48- ++nnvm::get<VariableParam>(e.node ->attrs .parsed ).version ;
48+ e. version = ++nnvm::get<VariableParam>(e.node ->attrs .parsed ).version ;
4949 }
5050 }
5151 }
@@ -98,14 +98,20 @@ Symbol Symbol::Copy() const {
9898 std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
9999 // use DFSVisit to copy all the nodes
100100 DFSVisit (this ->outputs , [&old_new](const std::shared_ptr<Node>& node) {
101- old_new[node.get ()] = std::make_shared<Node>(*node);
101+ std::shared_ptr<Node> np = Node::Create ();
102+ np->op = node->op ;
103+ np->attrs = node->attrs ;
104+ old_new[node.get ()] = std::move (np);
102105 });
103106 // connect nodes of new graph
104107 for (const auto &kv : old_new) {
105108 for (const NodeEntry& e : kv.first ->inputs ) {
106109 Node *ptr = e.node .get ();
107110 kv.second ->inputs .emplace_back (NodeEntry{old_new[ptr], e.index , e.version });
108111 }
112+ for (const std::shared_ptr<Node>& p : kv.first ->control_deps ) {
113+ kv.second ->control_deps .emplace_back (old_new[p.get ()]);
114+ }
109115 }
110116 // set the head
111117 Symbol ret;
@@ -120,7 +126,7 @@ void Symbol::Print(std::ostream &os) const {
120126 os << " AtomicFunctor " << " Op:" << outputs[0 ].node ->op ->name << ' \n ' ;
121127 } else {
122128 // use DFSVisit to copy all the nodes
123- os << " Outputs:\n " ;
129+ os << " Symbol Outputs:\n " ;
124130 for (size_t i = 0 ; i < outputs.size (); ++i) {
125131 os << " \t output[" << i << " ]=" << outputs[i].node ->attrs .name
126132 << ' (' << outputs[i].index << " )\n " ;
@@ -129,7 +135,8 @@ void Symbol::Print(std::ostream &os) const {
129135 if (node->is_variable ()) {
130136 os << " Variable:" << node->attrs .name << ' \n ' ;
131137 } else {
132- os << " Name: " << node->attrs .name << " Op:" << node->op ->name << ' \n '
138+ os << " --------------------\n " ;
139+ os << " Op:" << node->op ->name << " , Name=" << node->attrs .name << ' \n '
133140 << " Inputs:\n " ;
134141 for (size_t i = 0 ; i < node->inputs .size (); ++i) {
135142 const NodeEntry& e = node->inputs [i];
@@ -141,9 +148,17 @@ void Symbol::Print(std::ostream &os) const {
141148 os << ' \n ' ;
142149 }
143150 }
144- os << " Attrs:\n " ;
145- for (auto &kv : node->attrs .dict ) {
146- os << ' \t ' << kv.first << ' =' << kv.second << ' \n ' ;
151+ if (!node->attrs .dict .empty ()) {
152+ os << " Attrs:\n " ;
153+ for (auto &kv : node->attrs .dict ) {
154+ os << ' \t ' << kv.first << ' =' << kv.second << ' \n ' ;
155+ }
156+ }
157+ if (node->control_deps .size () != 0 ) {
158+ os << " Control deps:\n " ;
159+ for (size_t i = 0 ; i < node->control_deps .size (); ++i) {
160+ os << " \t cdep[" << i << " ]=" << node->control_deps [i]->attrs .name << ' \n ' ;
161+ }
147162 }
148163 }
149164 });
@@ -203,8 +218,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
203218}
204219
205220// compositional logic
206- void Symbol::Compose (const std::vector< Symbol>& args,
207- const std::unordered_map<std::string, Symbol>& kwargs,
221+ void Symbol::Compose (const array_view< const Symbol* >& args,
222+ const std::unordered_map<std::string, const Symbol* >& kwargs,
208223 const std::string& name) {
209224 static auto & flist_inputs = Op::GetAttr<FListInputNames>(" FListInputNames" );
210225
@@ -213,11 +228,11 @@ void Symbol::Compose(const std::vector<Symbol>& args,
213228 CHECK (!outputs[0 ].node ->is_variable ()) << " Variable cannot be composed" ;
214229 // parameter check.
215230 for (size_t i = 0 ; i < args.size (); ++i) {
216- CHECK_EQ (args[i]. outputs .size (), 1 )
231+ CHECK_EQ (args[i]-> outputs .size (), 1 )
217232 << " Argument " << i << " is a tuple, single value is required" ;
218233 }
219234 for (const auto & kv : kwargs) {
220- CHECK_EQ (kv.second . outputs .size (), 1 )
235+ CHECK_EQ (kv.second -> outputs .size (), 1 )
221236 << " Keyword Argument " << kv.first << " is a tuple, single value is required" ;
222237 }
223238 // assign new name
@@ -234,7 +249,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
234249 << " Incorrect number of arguments, requires " << n_req
235250 << " , provided " << args.size ();
236251 for (size_t i = 0 ; i < args.size (); ++i) {
237- n->inputs [i] = args[i]. outputs [0 ];
252+ n->inputs [i] = args[i]-> outputs [0 ];
238253 }
239254 // switch to keyword argument matching
240255 if (args.size () != n_req) {
@@ -247,7 +262,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
247262 for (size_t i = args.size (); i < n_req; ++i) {
248263 auto it = kwargs.find (arg_names[i]);
249264 if (it != kwargs.end () && it->first == arg_names[i]) {
250- n->inputs [i] = it->second . outputs [0 ];
265+ n->inputs [i] = it->second -> outputs [0 ];
251266 ++nmatched;
252267 } else {
253268 n->inputs [i] = NodeEntry{
@@ -266,8 +281,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
266281 } else {
267282 CHECK_EQ (kwargs.size (), 0 ) << " Variable length function do not accept kwargs" ;
268283 n->inputs .reserve (args.size ());
269- for (const Symbol& s : args) {
270- n->inputs .push_back (s. outputs [0 ]);
284+ for (const Symbol* s : args) {
285+ n->inputs .push_back (s-> outputs [0 ]);
271286 }
272287 }
273288 UpdateNodeVersion (n);
@@ -283,13 +298,13 @@ void Symbol::Compose(const std::vector<Symbol>& args,
283298 (const std::shared_ptr<Node> &node) {
284299 if (node->is_variable ()) {
285300 if (arg_counter < args.size ()) {
286- replace_map[node.get ()] = &(args[arg_counter]. outputs [0 ]);
301+ replace_map[node.get ()] = &(args[arg_counter]-> outputs [0 ]);
287302 ++arg_counter;
288303 } else {
289304 // match kwargs
290305 auto kit = kwargs.find (node->attrs .name );
291306 if (kit != kwargs.end ()) {
292- replace_map[node.get ()] = &(kit->second . outputs [0 ]);
307+ replace_map[node.get ()] = &(kit->second -> outputs [0 ]);
293308 ++nmatched;
294309 }
295310 }
@@ -334,8 +349,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
334349 }
335350}
336351
337- Symbol Symbol::operator () (const std::vector< Symbol>& args,
338- const std::unordered_map<std::string, Symbol>& kwargs,
352+ Symbol Symbol::operator () (const array_view< const Symbol* >& args,
353+ const std::unordered_map<std::string, const Symbol* >& kwargs,
339354 const std::string& name) const {
340355 Symbol s = this ->Copy ();
341356 s.Compose (args, kwargs, name);
0 commit comments