# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import time
from datetime import timedelta
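
# These environment variables must be set before torch and NCCL are imported and
# initialized below. CUDA_VISIBLE_DEVICES pins each replica group to a single
# GPU (the "% 4" assumes up to 4 replica groups sharing a 4-GPU host), and
# NCCL_HOSTID gives each group its own host identity so NCCL treats co-located
# groups as if they were on separate hosts.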
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4)
os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID)

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.distributed.elastic.multiprocessing.errors import record
from torchdata.stateful_dataloader import StatefulDataLoader

from torchft import (
    DistributedSampler,
    Manager,
    ProcessGroupGloo,
    ProcessGroupNCCL,
)
from torchft.checkpointing.pg_transport import PGTransport
from torchft.local_sgd import DiLoCo

logging.basicConfig(level=logging.INFO)


@record
def main() -> None:
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root="./cifar", train=True, download=True, transform=transform
    )

    # This shards the training set across all ranks and replica groups. We manage
    # the dataloaders on a per-replica-group basis with the assumption that the
    # majority of groups will be available, so few batches will be dropped.
    sampler = DistributedSampler(
        trainset,
        replica_group_id=REPLICA_GROUP_ID,
        num_replica_groups=NUM_REPLICA_GROUPS,
        group_rank=0,
        # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
        num_replicas=1,
        shuffle=True,
    )

    # This uses the torchdata StatefulDataLoader to be able to checkpoint and
    # restore the per-worker dataloader position.
    trainloader = StatefulDataLoader(
        trainset, batch_size=64, num_workers=2, sampler=sampler
    )

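    # torchft's ProcessGroupNCCL / ProcessGroupGloo are reconfigurable wrappers
    # around the corresponding torch.distributed backends, so the Manager can
    # reinitialize them whenever the quorum changes.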
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pg = (
        ProcessGroupNCCL(
            timeout=timedelta(seconds=30),
        )
        if torch.cuda.is_available()
        else ProcessGroupGloo(timeout=timedelta(seconds=5))
    )

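    # PGTransport sends checkpoints directly over the process group, so a
    # recovering replica group can pull the state_dict below from a live peer
    # instead of reading it from external storage.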
    transport = PGTransport(
        pg,
        timeout=timedelta(seconds=10),
        device=("cuda" if torch.cuda.is_available() else "cpu"),
    )

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnn = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            )

            final_dim = 10
            # We add a useless 1GB intermediate layer so we spend more time in dist
            # communication, so injected failures are more likely to cause issues
            # if they exist.
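            # target_size // final_dim // 4 rows of final_dim fp32 values works out
            # to roughly 25M x 10 x 4 bytes, i.e. about 1e9 bytes (~1 GB) of weights.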
            target_size = 1_000_000_000
            self.useless = nn.Embedding(target_size // final_dim // 4, final_dim)

            self.classifier = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, final_dim),
            )

        def forward(self, x):
            x = self.cnn(x)
            x = torch.flatten(x, 1)  # flatten all dimensions except batch
            x = self.classifier(x)
            x += self.useless.weight[0]
            return x

    m = Net().to(device)
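
    # DiLoCo uses two optimizers: the inner optimizer takes a local step on every
    # batch, while the outer optimizer is only applied to the averaged
    # pseudo-gradients at each sync point (AdamW inner / Nesterov SGD outer,
    # matching the usual DiLoCo recipe).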
    inner_optimizer = optim.AdamW(
        m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
    )
    outer_optimizer = optim.SGD(
        m.parameters(), lr=0.7, momentum=0.9, nesterov=True
    )
    criterion = nn.CrossEntropyLoss()

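    # state_dict() / load_state_dict() are the callbacks the Manager uses to
    # snapshot and restore training state, e.g. when a recovering replica group
    # pulls a live checkpoint from a healthy peer over the checkpoint transport.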
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        diloco.original_parameters = state_dict["original_params"]
        for name in diloco.original_parameters.keys():
            diloco.original_parameters[name] = diloco.original_parameters[name].to(
                device
            )
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "original_params": diloco.original_parameters,
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

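    # The Manager tracks quorum membership across replica groups and coordinates
    # recovery. use_async_quorum=False runs the quorum synchronously rather than
    # overlapping it with the forward pass, presumably so membership is settled
    # before each DiLoCo sync.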
    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    print(m)
    num_params = sum(p.numel() for p in m.parameters())
    print(f"Total number of parameters: {num_params}")

    sort_by_keyword = "self_" + device + "_time_total"

    def trace_handler(p):
        output = p.key_averages().table(
            sort_by=sort_by_keyword,
            row_limit=100,
        )
        print(output)
        p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

    # You can use epoch-based training, but with faults it's easier to use
    # step-based training.
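
    # Profiler schedule: skip the first 5 steps, warm up for 1, record 10 active
    # steps, then repeat the cycle a second time. trace_handler prints a summary
    # table and exports a Chrome trace at the end of each active window.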
    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
        on_trace_ready=trace_handler,
        record_shapes=True,
        profile_memory=True,
    )

    prof.start()

    num_local_steps = 0
    sync_every = 100
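
    # DiLoCo: the inner optimizer steps locally on every batch; every sync_every
    # local steps the pseudo-gradients (the delta between the saved global
    # parameters and the current local parameters) are averaged across replica
    # groups and applied with the outer optimizer.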
    with DiLoCo(
        manager,
        m,
        inner_optimizer,
        outer_optimizer,
        backup_device=device,
        sync_every=sync_every,
    ) as diloco:
        while True:
            for i, (inputs, labels) in enumerate(trainloader):
                prof.step()

                inputs = inputs.to(device)
                labels = labels.to(device)

                # must be called at the beginning of each train loop
                # Quorum computation is triggered here but only needed in the backward pass.
                inner_optimizer.zero_grad()

                out = m(inputs)
                loss = criterion(out, labels)

                # Gradient allreduce overlaps with the backward pass.
                loss.backward()

                # must be called at the end of the train loop
                # This may not actually step the optimizer if an error occurred during grad allreduce.
                inner_optimizer.step()
                num_local_steps += 1

                if manager.current_step() % 100 == 0:
                    print(f"[{manager.current_step()}] loss = {loss.item()}")

                if num_local_steps % sync_every == 0:
                    print(f"Number of inner optimizer steps completed: {num_local_steps}")

                # TODO (by the user): periodically checkpoint model, optim, manager and dataloader

                # You typically want to checkpoint the dataloader frequently (possibly
                # every step) to avoid repeated batches, since its state is specific to
                # each replica group.

                # Model, optim and manager checkpoints can be done less frequently as
                # they're shared across all groups and will load from existing replicas
                # as long as not every worker goes down.

                if manager.current_step() >= 10000:
                    # complete training
                    prof.stop()
                    sys.exit(0)
                time.sleep(0.01)


if __name__ == "__main__":
    main()