@@ -24,7 +24,7 @@ def run(args):
     config.n_embd = args.n_embd or config.n_embd
     config.n_layer = args.n_layer or config.n_layer
     config.n_head = args.n_head or config.n_head
-    print("Using device:", args.device)
+    print("[Rank {}] Using device: {}".format(args.rank, args.device))
 
     # Create model
     model_class = GPT2ForSequenceClassification
@@ -41,13 +41,19 @@ def run(args):
     example_inputs = generate_inputs_for_model(
         model_class, gpt2, model_name, args.batch_size, args.device)
 
+    assert not args.autosplit or not args.graphsplit
+
     split_policy = None
     split_spec = None
 
     if args.autosplit:
         # Automatic split
         from pippy import split_into_equal_size
         split_policy = split_into_equal_size(args.world_size)
+    elif args.graphsplit:
+        # Graph-based split
+        from pippy import split_by_graph
+        split_policy = split_by_graph(args.world_size)
     else:
         # Use manual split spec
         decoders_per_rank = (gpt2.config.n_layer + args.world_size - 1) // args.world_size
@@ -112,6 +118,7 @@ def run(args):
     parser.add_argument('--n_layer', type=int, default=None)
     parser.add_argument('--n_head', type=int, default=None)
     parser.add_argument('--autosplit', action="store_true")
+    parser.add_argument('--graphsplit', action="store_true")
 
     args = parser.parse_args()
 
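The new --graphsplit flag is intended to be mutually exclusive with --autosplit, which the diff enforces with an assert after parsing. The sketch below mirrors the split-policy selection logic from the diff as a standalone snippet; the pippy imports and the world_size argument come from the diff itself, while the rest of the scaffolding (the argparse mutually exclusive group used in place of the assert, the default world size) is illustrative, not part of the change.

import argparse

# Sketch only: mirrors the split-policy selection added in the diff above.
# `split_into_equal_size` / `split_by_graph` are taken from the diff; the
# mutually exclusive group is an illustrative alternative to the
# `assert not args.autosplit or not args.graphsplit` guard.
parser = argparse.ArgumentParser()
parser.add_argument('--world_size', type=int, default=2)
group = parser.add_mutually_exclusive_group()
group.add_argument('--autosplit', action="store_true")
group.add_argument('--graphsplit', action="store_true")
args = parser.parse_args()

split_policy = None

if args.autosplit:
    # Split the traced model into roughly equal-sized stages, one per rank
    from pippy import split_into_equal_size
    split_policy = split_into_equal_size(args.world_size)
elif args.graphsplit:
    # Graph-based split: the new path introduced by this change
    from pippy import split_by_graph
    split_policy = split_by_graph(args.world_size)
# else: fall back to the manual split spec built in the original example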