@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
2121 return def;
2222}
2323
24+ inline bool IsMutate (const std::vector<uint32_t >& mutate_inputs, uint32_t i) {
25+ if (mutate_inputs.size () == 0 ) return false ;
26+ auto it = std::lower_bound (
27+ mutate_inputs.begin (), mutate_inputs.end (), i);
28+ return (it != mutate_inputs.end ()) && (*it == i);
29+ }
30+
2431Graph OrderMutation (const Graph& src) {
2532 std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
2633 DFSVisit (src.outputs , [&version_hist](const NodePtr& n) {
@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
3744 // start preparing for remapping the nodes.
3845 std::unordered_map<Node*, NodePtr> old_new;
3946 auto prepare = [&version_hist, &old_new] (const NodePtr& n) {
40- static auto & fmutate_inputs = Op::GetAttr<FMutateInput>(" FMutateInput" );
47+ static auto & fmutate_inputs = Op::GetAttr<FMutateInputs>(" FMutateInputs" );
48+ std::vector<uint32_t > mutate_inputs;
49+ if (!n->is_variable () && fmutate_inputs.count (n->op )) {
50+ mutate_inputs = fmutate_inputs[n->op ](n->attrs );
51+ }
52+ std::sort (mutate_inputs.begin (), mutate_inputs.end ());
53+
4154 bool need_repl = false ;
4255 for (size_t i = 0 ; i < n->inputs .size (); ++i) {
4356 const NodeEntry& e = n->inputs [i];
@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
4659 auto it = version_hist.find (e.node .get ());
4760 if (it != version_hist.end ()) {
4861 std::vector<NodeEntry>& vec = it->second ;
49- uint32_t is_mutate =
50- fmutate_inputs.count (n->op ) ? fmutate_inputs[n->op ](n->attrs , i) : 0 ;
51- vec.emplace_back (NodeEntry{n, is_mutate, e.version });
62+ vec.emplace_back (NodeEntry{n, IsMutate (mutate_inputs, i), e.version });
5263 }
5364 } else {
5465 if (old_new.count (e.node .get ()) != 0 ) need_repl = true ;
@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
91102 get_with_default (old_new, p.get (), p));
92103 }
93104 // add control deps
94- static auto & fmutate_inputs = Op::GetAttr<FMutateInput>(" FMutateInput" );
105+ static auto & fmutate_inputs = Op::GetAttr<FMutateInputs>(" FMutateInputs" );
106+ std::vector<uint32_t > mutate_inputs;
107+ if (fmutate_inputs.count (kv.first ->op )) {
108+ mutate_inputs = fmutate_inputs[kv.first ->op ](kv.first ->attrs );
109+ }
110+ std::sort (mutate_inputs.begin (), mutate_inputs.end ());
111+
95112 for (size_t i = 0 ; i < kv.first ->inputs .size (); ++i) {
96113 const NodeEntry& e = kv.first ->inputs [i];
97114 if (e.node ->is_variable () && version_hist.count (e.node .get ()) != 0 ) {
98- FMutateInput fmutate = fmutate_inputs.get (kv.first ->op , nullptr );
99- uint32_t is_mutate = (fmutate == nullptr ) ? 0 : fmutate (kv.first ->attrs , i);
100115 std::vector<NodeEntry>& vec = version_hist.at (e.node .get ());
101-
102116 auto it = std::lower_bound (vec.begin (), vec.end (),
103117 NodeEntry{nullptr , 1 , e.version },
104118 comparator);
105- if (is_mutate != 0 ) {
119+ if (IsMutate (mutate_inputs, i) ) {
106120 int read_dep = 0 ;
107121 while (it != vec.begin ()) {
108122 --it;
0 commit comments