Skip to content

Commit cfe2b31

Browse files
author
spupyrev
committed
Added graph-split presolve to speedup computation
1 parent 62fca16 commit cfe2b31

File tree

5 files changed

+186
-63
lines changed

5 files changed

+186
-63
lines changed

examples/huggingface/pippy_gpt2.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def run(args):
2424
config.n_embd = args.n_embd or config.n_embd
2525
config.n_layer = args.n_layer or config.n_layer
2626
config.n_head = args.n_head or config.n_head
27-
print("Using device:", args.device)
27+
print("[Rank {}] Using device: {}".format(args.rank, args.device))
2828

2929
# Create model
3030
model_class = GPT2ForSequenceClassification
@@ -41,13 +41,19 @@ def run(args):
4141
example_inputs = generate_inputs_for_model(
4242
model_class, gpt2, model_name, args.batch_size, args.device)
4343

44+
assert not args.autosplit or not args.graphsplit
45+
4446
split_policy = None
4547
split_spec = None
4648

4749
if args.autosplit:
4850
# Automatic split
4951
from pippy import split_into_equal_size
5052
split_policy = split_into_equal_size(args.world_size)
53+
elif args.graphsplit:
54+
# Graph-based split
55+
from pippy import split_by_graph
56+
split_policy = split_by_graph(args.world_size)
5157
else:
5258
# Use manual split spec
5359
decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
@@ -112,6 +118,7 @@ def run(args):
112118
parser.add_argument('--n_layer', type=int, default=None)
113119
parser.add_argument('--n_head', type=int, default=None)
114120
parser.add_argument('--autosplit', action="store_true")
121+
parser.add_argument('--graphsplit', action="store_true")
115122

116123
args = parser.parse_args()
117124

pippy/ModelSplit.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torch
66
import torch.fx as fx
77

8+
from pippy.graphsplit import split_by_graph_with_num_stages
9+
810
from ._IR import aten_pipe_split_alias
911

1012

pippy/_IR.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -925,10 +925,10 @@ def set_multi_use_param_spec(
925925
if isinstance(multi_use_param_spec, MultiUseParameterConfig):
926926
multi_use_params_qualnames[param] = multi_use_param_spec
927927
elif isinstance(multi_use_param_spec, dict):
928-
multi_use_params_qualnames[
929-
param
930-
] = multi_use_param_spec.get(
931-
param, MultiUseParameterConfig.TRANSMIT
928+
multi_use_params_qualnames[param] = (
929+
multi_use_param_spec.get(
930+
param, MultiUseParameterConfig.TRANSMIT
931+
)
932932
)
933933
else:
934934
raise ValueError(

0 commit comments

Comments
 (0)