@@ -202,3 +202,57 @@ def _split_into_nstages_equal_size(
202
202
return gm
203
203
204
204
return _split_into_nstages_equal_size
205
+
206
+
207
+ """
208
+ Create a Callable that splits a model into a given number of stages, based on the computation graph, while
209
+ trying to minimize the communication between the stages and to balance the computation
210
+ Input:
211
+ nstages: the number of stages to split the module into
212
+ Output:
213
+ a Callable that transforms an input `fx.GraphModule` into an output `fx.GraphModule` that has `pipe_split` inserted
214
+ between `nstages` stages
215
+ """
216
+
217
+
218
+ def split_by_graph (nstages : int ) -> Callable [[fx .GraphModule ], fx .GraphModule ]:
219
+ def _split_by_graph (
220
+ gm : fx .GraphModule ,
221
+ ) -> fx .GraphModule :
222
+ node_param_sizes = _analyze_node_size (gm )
223
+ node2stage = split_by_graph_with_num_stages (
224
+ gm , nstages , node_param_sizes
225
+ )
226
+
227
+ # Remove existing split points
228
+ for node in gm .graph .nodes :
229
+ if "pipe_split" in node .name :
230
+ gm .graph .erase_node (node )
231
+
232
+ # Modify the graph by grouping nodes on the same stage and adding
233
+ # pipe_splits between the stages
234
+ node_order = [node for node in gm .graph .nodes if node in node2stage ]
235
+ last_node = None
236
+ for stage_idx in range (nstages ):
237
+ nodes_at_stage = [
238
+ node
239
+ for node in node_order
240
+ if node in node2stage and node2stage [node ] == stage_idx
241
+ ]
242
+ for idx , node in enumerate (nodes_at_stage ):
243
+ if last_node is not None and last_node .next != node :
244
+ last_node .append (node )
245
+ last_node = node
246
+ # Insert pipe_split nodes after each stage, except the last one
247
+ if stage_idx + 1 != nstages and last_node is not None :
248
+ with gm .graph .inserting_after (last_node ):
249
+ last_node = gm .graph .call_function (
250
+ aten_pipe_split_alias , (), {}
251
+ )
252
+
253
+ # Since we transformed the graph, recompile the module
254
+ gm .recompile ()
255
+ gm .graph .lint ()
256
+ return gm
257
+
258
+ return _split_by_graph
0 commit comments