 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 """
 LocalSGD
 =========
-
 This module implements a fault tolerant version of LocalSGD and related methods.
 """
-
-from typing import Any, Dict, List, Mapping, Optional
+import logging
+from types import TracebackType
+from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Type

 import torch
 from torch import nn, optim
+from torch.nn.parameter import Parameter
+from torch.optim.optimizer import Optimizer
+from torch.utils.hooks import RemovableHandle

 from torchft.manager import Manager

+logger: logging.Logger = logging.getLogger(__name__)
+

-class LocalSGD(nn.Module):
+class LocalSGD:
     """
-    LocalSGD is a model wrapper similar to DistributedDataParallel that
+    LocalSGD is a context manager that
     implements the algorithm described in https://arxiv.org/pdf/1805.09767

     This will synchronize the model parameters periodically in a fault tolerant
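
For illustration, a minimal usage sketch of the context-manager API introduced in this diff (not part of the commit; `manager`, `model`, `optimizer`, `criterion`, and `trainloader` are assumed to be constructed elsewhere):

with LocalSGD(manager, model, optimizer, sync_every=100):
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        # The post-step hook registered in __enter__ counts this step and
        # triggers a synchronization every sync_every steps.
        optimizer.step()
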
@@ -68,18 +72,14 @@ def __init__(
             pin_memory: Whether to pin the memory used for the backup of the model parameters.
         """
         super().__init__()
-
         self._manager = manager
         self._model = model
+        self._local_optimizer = optimizer
         self._local_step = 0
-        self._started_step = False
         self._sync_every = sync_every
         assert sync_every >= 1, "sync_every must be greater than or equal to 1"
-
         device = backup_device or torch.device("cpu")
-
         self._backup_parameters: Dict[str, torch.Tensor] = {}
-
         for name, p in self._model.named_parameters():
             t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device)
             if (
@@ -90,95 +90,150 @@ def __init__(
                 t = t.pin_memory()
             self._backup_parameters[name] = t

+        self._hooks: List[RemovableHandle] = []
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()

-        optimizer.register_step_post_hook(self._step_post_hook)
+    def __enter__(self) -> "LocalSGD":
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        if exc_type is not None:
+            # If an exception occurred, restore parameters
+            self._restore_parameters()
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions

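
The context manager relies on the stock `Optimizer.register_step_post_hook` API rather than wrapping the model. A minimal standalone sketch of that mechanism (toy model and optimizer, not part of this diff):

import torch
from torch import nn, optim

net = nn.Linear(4, 2)
opt = optim.SGD(net.parameters(), lr=0.1)
step_count = {"n": 0}

def post_step(_optim, _args, _kwargs) -> None:
    # LocalSGD._step_post_hook does the same bookkeeping, then calls sync()
    # once the counter reaches sync_every.
    step_count["n"] += 1

handle = opt.register_step_post_hook(post_step)
net(torch.randn(8, 4)).sum().backward()
opt.step()       # fires post_step -> step_count["n"] == 1
handle.remove()  # __exit__ removes the hook the same way
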
     def _save_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                self._backup_parameters[name].copy_(p.data, non_blocking=True)

     def _restore_parameters(self) -> None:
-        # TODO: consider running copy on a separate stream
-        for name, p in self._model.named_parameters():
-            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model.named_parameters():
+                p.data.copy_(self._backup_parameters[name], non_blocking=False)

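
The backup tensors above live in (optionally pinned) host memory. A small standalone sketch of the save/restore copy pattern, assuming a CUDA device is available (not part of this diff):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
p = torch.randn(1024, device=device)

backup = torch.empty_like(p, device="cpu")
if device == "cuda":
    backup = backup.pin_memory()     # pinned host memory enables async copies

backup.copy_(p, non_blocking=True)   # save: device -> host, may overlap with compute
p.copy_(backup, non_blocking=False)  # restore: synchronous, as in _restore_parameters
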
-    # pyre-fixme[14]: support state_dict args
-    def state_dict(self) -> Dict[str, object]:
-        """
-        state_dict returns the state_dict from the last time LocalSGD
-        synchronized and not the current weights.
-        """
-        state_dict = self._model.state_dict()
-        for name, p in self._backup_parameters.items():
-            assert name in state_dict
-            state_dict[name] = p
-        return state_dict
-
-    def load_state_dict(
-        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
     ) -> None:
         """
-        Loads the state dict to the model and the backup parameters.
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()

-        This must be called while the model weights aren't being modified to
-        avoid corrupting the backup weights.
+    def sync(self) -> None:
         """
-        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
-        self._save_parameters()
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0

-    def forward(self, *args: object, **kwargs: object) -> object:
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        This method is intended to be overridden by subclasses to implement custom
+        synchronization logic.
         """
-        Run the model parameters.
+        self._average()
+        if self._manager.should_commit():
+            self._save_parameters()
+        else:
+            # commit failed, restore from the backup parameters
+            self._restore_parameters()

-        This should be called before the optimizer step.
+    def _average(self) -> None:
+        # TODO: do we need to broadcast buffers like DDP does?

-        This will start the quorum and save the parameters if this is the first step.
-        """
-        if self._local_step == 0:
-            self._manager.start_quorum()
+        works = []
+
+        for p in self._model.parameters():
+            # TODO: bucketize parameters
+            works.append(self._manager.allreduce(p.data.detach()))

-        self._started_step = True
+        for work in works:
+            work.wait()

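
`_average` issues one allreduce per parameter and only then waits on all of them, so the collectives can overlap. The same issue-all-then-wait pattern expressed with plain `torch.distributed` (a sketch assuming an initialized process group and an existing `model`, not the Manager API used in this diff):

import torch.distributed as dist

works = []
for p in model.parameters():
    # async_op=True returns immediately with a work handle, analogous to Manager.allreduce
    works.append(dist.all_reduce(p.data, op=dist.ReduceOp.SUM, async_op=True))
for work in works:
    work.wait()
for p in model.parameters():
    p.data /= dist.get_world_size()  # turn the summed tensors into an average
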
-        return self._model.forward(*args, **kwargs)

-    def _step_post_hook(
-        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
-    ) -> None:
-        """
-        This hook is registered on the optimizer and is called after the optimizer step.
+class DiLoCo(LocalSGD):
+    """
+    DiLoCo is a subclass of LocalSGD that overrides the synchronization
+    mechanism to average and synchronize the pseudogradients (the delta between the previous global weights and the current local weights).

-        This will call the allreduce on the model weights every sync_every steps.
-        If any errors occur it will restore to the weights from the previous sync.
+    diloco: https://arxiv.org/pdf/2311.08105
+    """

-        ``forward`` must be called before this function.
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+    ) -> None:
+        if manager._use_async_quorum:
+            raise ValueError(
197+ "Using DiLoCo require synchronous quorum to be enabled. "
198+ "Ensure that the manager is initialized with use_async_quorum=False"
199+ )
200+ super ().__init__ (
201+ manager , model , inner_optimizer , sync_every , backup_device , pin_memory
202+ )
203+ self ._outer_optimizer = outer_optimizer
204+
205+ def _perform_sync (self ) -> None :
206+ """
+        Overrides the sync method to calculate the pseudogradients, average them across the manager group, and
+        step using the outer optimizer.
157209 """
158- assert self ._started_step , "forward must be called before step"
159- self ._started_step = False
160210
161- self ._local_step += 1
211+ # Set the .grad field of each parameter to its pseudogradient
212+ for name , p in self ._model .named_parameters ():
213+ assert name in self ._backup_parameters
214+ pseudogradient = p .data - self ._backup_parameters [name ]
215+ p .grad = pseudogradient
162216
163- if self ._local_step >= self . _sync_every :
164- self . _local_step = 0
165- self ._average ()
217+ self ._average_grads ()
218+ # Restore the parameters back to the previous state
219+ self ._restore_parameters ()
166220
167- if self ._manager .should_commit ():
168- # save the parameters so we can restore from them later if necessary.
169- self ._save_parameters ()
170- else :
171- # commit failed, restore from the backup parameters
172- self ._restore_parameters ()
173-
174- def _average (self ) -> None :
175- # TODO: do we need to broadcast buffers like DDP does?
221+ if self ._manager .should_commit ():
222+ # Use the outer optimizer to update the model parameters
223+ self ._outer_optimizer .step ()
224+ self ._save_parameters ()
225+ self ._outer_optimizer .zero_grad ()
176226
227+ def _average_grads (self ) -> None :
228+ """
229+ Average the gradients across the diloco group.
230+ """
177231 works = []
178-
179232 for p in self ._model .parameters ():
180- # TODO: bucketize parameters
181- works .append (self ._manager .allreduce (p .data .detach ()))
182-
233+ # Perform allreduce on the pseudogradients
234+ assert p .grad is not None
235+ work = self ._manager .allreduce (p .grad )
236+ works .append (work )
237+ # Wait for all allreduce operations to complete
183238 for work in works :
184239 work .wait ()
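
For illustration, a usage sketch of the DiLoCo wrapper added above (not part of this diff; `manager`, `model`, `criterion`, and `trainloader` are assumed to exist, and the AdamW-inner / Nesterov-SGD-outer split follows the DiLoCo paper's recipe with placeholder hyperparameters):

inner_opt = torch.optim.AdamW(model.parameters(), lr=4e-4)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.7, momentum=0.9, nesterov=True)

# The manager must be created with use_async_quorum=False, as enforced in __init__ above.
with DiLoCo(manager, model, inner_opt, outer_opt, sync_every=500):
    for inputs, labels in trainloader:
        inner_opt.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        # Every sync_every steps the pseudogradients are averaged and outer_opt steps.
        inner_opt.step()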