A graph-based pipeline splitting #1080

Merged · 3 commits · May 31, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -1,6 +1,7 @@
__pycache__
build
pippy.egg-info
torchpippy.egg-info
pippy/version.py
dist
.idea/
9 changes: 8 additions & 1 deletion examples/huggingface/pippy_gpt2.py
@@ -21,7 +21,7 @@ def run(args):
config.n_embd = args.n_embd or config.n_embd
config.n_layer = args.n_layer or config.n_layer
config.n_head = args.n_head or config.n_head
print("Using device:", args.device)
print("[Rank {}] Using device: {}".format(args.rank, args.device))

# Create model
model_class = GPT2ForSequenceClassification
@@ -38,13 +38,19 @@ def run(args):
example_inputs = generate_inputs_for_model(
model_class, gpt2, model_name, args.batch_size, args.device)

assert not args.autosplit or not args.graphsplit

split_policy = None
split_spec = None

if args.autosplit:
# Automatic split
from pippy import split_into_equal_size
split_policy = split_into_equal_size(args.world_size)
elif args.graphsplit:
# Graph-based split
from pippy import split_by_graph
split_policy = split_by_graph(args.world_size)
else:
# Use manual split spec
decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
@@ -106,6 +112,7 @@ def run(args):
parser.add_argument('--n_layer', type=int, default=None)
parser.add_argument('--n_head', type=int, default=None)
parser.add_argument('--autosplit', action="store_true")
parser.add_argument('--graphsplit', action="store_true")

args = parser.parse_args()

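For readers skimming the diff: the selection logic added to this example reduces to a single policy-factory pattern. A condensed sketch follows; `choose_split_policy` is an illustrative helper, not part of the PR.

    # Condensed sketch of the policy selection added to pippy_gpt2.py.
    # Both factories take a stage count and return a callable that rewrites
    # an fx.GraphModule; at most one of the two flags may be set.
    from pippy import split_by_graph, split_into_equal_size

    def choose_split_policy(args):
        assert not (args.autosplit and args.graphsplit)
        if args.autosplit:
            return split_into_equal_size(args.world_size)  # equal-size stages
        if args.graphsplit:
            return split_by_graph(args.world_size)  # graph-based stages
        return None  # fall back to the manual split spec
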
56 changes: 56 additions & 0 deletions pippy/ModelSplit.py
@@ -5,6 +5,8 @@
import torch
import torch.fx as fx

from pippy.graphsplit import split_by_graph_with_num_stages

from ._IR import aten_pipe_split_alias


@@ -202,3 +204,57 @@ def _split_into_nstages_equal_size(
return gm

return _split_into_nstages_equal_size


"""
Create a Callable that splits a model into a given number of stages, based on the computation graph, while
trying to minimize the communication between the stages and to balance the computation
Input:
nstages: the number of stages to split the module into
Output:
a Callable that transforms an input `fx.GraphModule` into an output `fx.GraphModule` that has `pipe_split` inserted
between `nstages` stages
"""


def split_by_graph(nstages: int) -> Callable[[fx.GraphModule], fx.GraphModule]:
def _split_by_graph(
gm: fx.GraphModule,
) -> fx.GraphModule:
node_param_sizes = _analyze_node_size(gm)
node2stage = split_by_graph_with_num_stages(
gm, nstages, node_param_sizes
)

# Remove existing split points
for node in gm.graph.nodes:
if "pipe_split" in node.name:
gm.graph.erase_node(node)

# Modify the graph by grouping nodes on the same stage and adding
# pipe_splits between the stages
node_order = [node for node in gm.graph.nodes if node in node2stage]
last_node = None
for stage_idx in range(nstages):
nodes_at_stage = [
node
for node in node_order
if node in node2stage and node2stage[node] == stage_idx
]
for idx, node in enumerate(nodes_at_stage):
if last_node is not None and last_node.next != node:
last_node.append(node)
last_node = node
# Insert pipe_split nodes after each stage, except the last one
if stage_idx + 1 != nstages and last_node is not None:
with gm.graph.inserting_after(last_node):
last_node = gm.graph.call_function(
aten_pipe_split_alias, (), {}
)

# Since we transformed the graph, recompile the module
gm.recompile()
gm.graph.lint()
return gm

return _split_by_graph
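Taken together, `split_by_graph` is usable as a standalone graph transform. A minimal sketch, assuming the returned callable can be applied to any symbolically traced module (the toy model, its shapes, and the stage count are hypothetical, chosen only to illustrate the interface):

    import torch
    import torch.fx as fx
    from pippy import split_by_graph

    # Hypothetical two-layer module, used only to illustrate the transform.
    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(16, 16)
            self.fc2 = torch.nn.Linear(16, 16)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    gm = fx.symbolic_trace(Toy())
    policy = split_by_graph(nstages=2)  # Callable[[fx.GraphModule], fx.GraphModule]
    gm = policy(gm)                     # pipe_split markers now delimit the 2 stages
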
7 changes: 6 additions & 1 deletion pippy/__init__.py
@@ -10,7 +10,11 @@
)
from ._PipelineStage import PipelineStage
from .ManualPipelineStage import ManualPipelineStage
from .ModelSplit import split_into_equal_size, split_on_size_threshold
from .ModelSplit import (
split_by_graph,
split_into_equal_size,
split_on_size_threshold,
)
from .PipelineSchedule import (
Schedule1F1B,
ScheduleGPipe,
@@ -27,6 +31,7 @@
"annotate_split_points",
"split_into_equal_size",
"split_on_size_threshold",
"split_by_graph",
"pipeline",
"Schedule1F1B",
"ScheduleGPipe",
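With the re-export above, the graph-based splitter sits alongside the existing policies at the package root. A brief usage note (the stage count here is arbitrary):

    from pippy import split_by_graph, split_into_equal_size, split_on_size_threshold

    # All three are importable from the package root after this change; the
    # graph-based factory mirrors split_into_equal_size's interface.
    split_policy = split_by_graph(4)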