Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A graph-based pipeline splitting #1080

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added graph-split presolve to speedup computation
  • Loading branch information
spupyrev committed May 29, 2024
commit 6403d71a163dfabc58bccc0222f5d1990aa2d802
9 changes: 8 additions & 1 deletion examples/huggingface/pippy_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run(args):
config.n_embd = args.n_embd or config.n_embd
config.n_layer = args.n_layer or config.n_layer
config.n_head = args.n_head or config.n_head
print("Using device:", args.device)
print("[Rank {}] Using device: {}".format(args.rank, args.device))

# Create model
model_class = GPT2ForSequenceClassification
Expand All @@ -38,13 +38,19 @@ def run(args):
example_inputs = generate_inputs_for_model(
model_class, gpt2, model_name, args.batch_size, args.device)

assert not args.autosplit or not args.graphsplit

split_policy = None
split_spec = None

if args.autosplit:
# Automatic split
from pippy import split_into_equal_size
split_policy = split_into_equal_size(args.world_size)
elif args.graphsplit:
# Graph-based split
from pippy import split_by_graph
split_policy = split_by_graph(args.world_size)
else:
# Use manual split spec
decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
Expand Down Expand Up @@ -106,6 +112,7 @@ def run(args):
parser.add_argument('--n_layer', type=int, default=None)
parser.add_argument('--n_head', type=int, default=None)
parser.add_argument('--autosplit', action="store_true")
parser.add_argument('--graphsplit', action="store_true")

args = parser.parse_args()

Expand Down
2 changes: 2 additions & 0 deletions pippy/ModelSplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import torch.fx as fx

from pippy.graphsplit import split_by_graph_with_num_stages

from ._IR import aten_pipe_split_alias


Expand Down
Loading
Loading