Skip to content
This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit 62fca16

Browse files
author
spupyrev
committed
A graph-based pipeline splitting
1 parent 2aa360f commit 62fca16

File tree

5 files changed

+646
-1
lines changed

5 files changed

+646
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__pycache__
22
build
33
pippy.egg-info
4+
torchpippy.egg-info
45
pippy/version.py
56
dist
67
.idea/

pippy/ModelSplit.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,57 @@ def _split_into_nstages_equal_size(
202202
return gm
203203

204204
return _split_into_nstages_equal_size
205+
206+
207+
"""
208+
Create a Callable that splits a model into a given number of stages, based on the computation graph, while
209+
trying to minimize the communication between the stages and to balance the computation
210+
Input:
211+
nstages: the number of stages to split the module into
212+
Output:
213+
a Callable that transforms an input `fx.GraphModule` into an output `fx.GraphModule` that has `pipe_split` inserted
214+
between `nstages` stages
215+
"""
216+
217+
218+
def split_by_graph(nstages: int) -> Callable[[fx.GraphModule], fx.GraphModule]:
219+
def _split_by_graph(
220+
gm: fx.GraphModule,
221+
) -> fx.GraphModule:
222+
node_param_sizes = _analyze_node_size(gm)
223+
node2stage = split_by_graph_with_num_stages(
224+
gm, nstages, node_param_sizes
225+
)
226+
227+
# Remove existing split points
228+
for node in gm.graph.nodes:
229+
if "pipe_split" in node.name:
230+
gm.graph.erase_node(node)
231+
232+
# Modify the graph by grouping nodes on the same stage and adding
233+
# pipe_splits between the stages
234+
node_order = [node for node in gm.graph.nodes if node in node2stage]
235+
last_node = None
236+
for stage_idx in range(nstages):
237+
nodes_at_stage = [
238+
node
239+
for node in node_order
240+
if node in node2stage and node2stage[node] == stage_idx
241+
]
242+
for idx, node in enumerate(nodes_at_stage):
243+
if last_node is not None and last_node.next != node:
244+
last_node.append(node)
245+
last_node = node
246+
# Insert pipe_split nodes after each stage, except the last one
247+
if stage_idx + 1 != nstages and last_node is not None:
248+
with gm.graph.inserting_after(last_node):
249+
last_node = gm.graph.call_function(
250+
aten_pipe_split_alias, (), {}
251+
)
252+
253+
# Since we transformed the graph, recompile the module
254+
gm.recompile()
255+
gm.graph.lint()
256+
return gm
257+
258+
return _split_by_graph

pippy/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
)
1111
from ._PipelineStage import PipelineStage
1212
from .ManualPipelineStage import ManualPipelineStage
13-
from .ModelSplit import split_into_equal_size, split_on_size_threshold
13+
from .ModelSplit import (
14+
split_by_graph,
15+
split_into_equal_size,
16+
split_on_size_threshold,
17+
)
1418
from .PipelineSchedule import (
1519
Schedule1F1B,
1620
ScheduleGPipe,
@@ -27,6 +31,7 @@
2731
"annotate_split_points",
2832
"split_into_equal_size",
2933
"split_on_size_threshold",
34+
"split_by_graph",
3035
"pipeline",
3136
"Schedule1F1B",
3237
"ScheduleGPipe",

0 commit comments

Comments
 (0)