@@ -25,7 +25,6 @@ Graph PlaceDevice(Graph src) {
2525 const Op* copy_op = Op::Get (src.GetAttr <std::string>(" device_copy_op" ));
2626 auto & device_assign_map = src.GetAttr <DeviceAssignMap>(" device_assign_map" );
2727 const IndexedGraph& idx = src.indexed_graph ();
28-
2928 DeviceVector device;
3029 // copy on write semanatics
3130 if (src.attrs .count (" device" ) != 0 ) {
@@ -79,17 +78,27 @@ Graph PlaceDevice(Graph src) {
7978 src.attrs [" device" ] = std::make_shared<any>(std::move (device));
8079 return src;
8180 }
82-
8381 std::map<std::tuple<uint32_t , uint32_t , int >, NodePtr> copy_map;
8482 std::vector<NodePtr> new_node_map (idx.num_nodes (), nullptr );
8583 std::unordered_map<const Node*, int > new_device_map;
84+ static auto & fmutate_inputs = Op::GetAttr<FMutateInputs>(" FMutateInputs" );
8685
8786 // insert copy node
8887 for (uint32_t nid = 0 ; nid < idx.num_nodes (); ++nid) {
8988 int dev_id = device[nid];
9089 const auto & inode = idx[nid];
9190 // check if mutation is needed
9291 bool need_mutate = false ;
92+ if (!inode.source ->is_variable () && fmutate_inputs.count (inode.source ->op ())) {
93+ for (uint32_t index : fmutate_inputs[inode.source ->op ()](inode.source ->attrs )) {
94+ auto e = inode.inputs [index];
95+ if (new_node_map[e.node_id ] != nullptr || dev_id != device[e.node_id ]) {
96+ LOG (FATAL) << " mutable state cannot go across device"
97+ << " op=" << inode.source ->op ()->name
98+ << " input_state_index=" << index;
99+ }
100+ }
101+ }
93102 for (const IndexedGraph::NodeEntry& e : inode.inputs ) {
94103 if (new_node_map[e.node_id ] != nullptr || dev_id != device[e.node_id ]) {
95104 need_mutate = true ; break ;
@@ -102,6 +111,9 @@ Graph PlaceDevice(Graph src) {
102111 }
103112 }
104113 }
114+ if (inode.source ->is_variable ()) {
115+ CHECK (!need_mutate) << " consistency check" ;
116+ }
105117 if (need_mutate) {
106118 NodePtr new_node = Node::Create ();
107119 new_node->attrs = inode.source ->attrs ;
@@ -120,7 +132,15 @@ Graph PlaceDevice(Graph src) {
120132 os << inode.source ->inputs [i].node ->attrs .name << " _" << e.index <<" _copy" ;
121133 copy_node->attrs .op = copy_op;
122134 copy_node->attrs .name = os.str ();
123- copy_node->inputs .push_back (inode.source ->inputs [i]);
135+ if (new_node_map[e.node_id ] != nullptr ) {
136+ copy_node->inputs .emplace_back (
137+ NodeEntry{new_node_map[e.node_id ], e.index , 0 });
138+ } else {
139+ copy_node->inputs .push_back (inode.source ->inputs [i]);
140+ }
141+ if (copy_node->attrs .op ->attr_parser != nullptr ) {
142+ copy_node->attrs .op ->attr_parser (&(copy_node->attrs ));
143+ }
124144 copy_map[copy_key] = copy_node;
125145 new_device_map[copy_node.get ()] = dev_id;
126146 new_node->inputs .emplace_back (
@@ -130,7 +150,7 @@ Graph PlaceDevice(Graph src) {
130150 if (new_node_map[e.node_id ] != nullptr ) {
131151 new_node->inputs .emplace_back (
132152 NodeEntry{new_node_map[e.node_id ], e.index , 0 });
133- } else {
153+ } else {
134154 new_node->inputs .push_back (inode.source ->inputs [i]);
135155 }
136156 }
@@ -150,7 +170,6 @@ Graph PlaceDevice(Graph src) {
150170 new_device_map[inode.source ] = dev_id;
151171 }
152172 }
153-
154173 // make the new graph
155174 Graph ret;
156175 for (const NodeEntry& e : src.outputs ) {
@@ -163,10 +182,11 @@ Graph PlaceDevice(Graph src) {
163182 }
164183 DeviceVector new_device_vec (ret.indexed_graph ().num_nodes ());
165184 for (uint32_t nid = 0 ; nid < ret.indexed_graph ().num_nodes (); ++nid) {
166- if (new_device_map.count (ret.indexed_graph ()[nid].source ) == 0 ) {
167- LOG (INFO) << " canot find " << ret.indexed_graph ()[nid].source ->attrs .name ;
185+ auto source = ret.indexed_graph ()[nid].source ;
186+ if (new_device_map.count (source) == 0 ) {
187+ LOG (FATAL) << " canot find " << source;
168188 }
169- new_device_vec[nid] = new_device_map.at (ret. indexed_graph ()[nid]. source );
189+ new_device_vec[nid] = new_device_map.at (source);
170190 }
171191 ret.attrs [" device" ] = std::make_shared<any>(std::move (new_device_vec));
172192 return ret;
0 commit comments