Description
Is your feature request related to a problem? Please describe.
When a checkpoint is saved, there should be a check for whether the object that state_dict() is called on is a ZeroRedundancyOptimizer instance.
Describe the solution you'd like
Before state_dict() is called, a consolidate_state_dict() call should be issued. This call must be made on all ranks, and every rank must point to the same consolidating rank.
There are two possible designs: either instantiate a Checkpoint handler on all ranks and have it save the checkpoint only on the designated rank, or create a separate handler that runs before the Checkpoint handler to perform the consolidation.
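As a rough illustration of the first design, here is a minimal sketch of a function that every rank calls and that only saves on the designated rank. The names `consolidate_then_save`, `save_fn`, and `rank` are hypothetical and not part of Ignite or PyTorch; the optimizer is only assumed to expose `consolidate_state_dict(to=...)` and `state_dict()`.

```python
from typing import Any, Callable, Dict


def consolidate_then_save(
    optimizer: Any,
    rank: int,
    save_fn: Callable[[Dict[str, Any]], None],
    recipient_rank: int = 0,
) -> bool:
    """Run on every rank; return True only on the rank that saved.

    Hypothetical helper: `optimizer` is assumed to behave like a
    ZeroRedundancyOptimizer whose method accepts the keyword `to`.
    """
    # Collective step: every rank calls this with the same target rank.
    optimizer.consolidate_state_dict(to=recipient_rank)

    # Only the designated rank holds the full state, so only it saves.
    if rank != recipient_rank:
        return False
    save_fn(optimizer.state_dict())
    return True
```

In a real setup `rank` would come from the distributed backend and `save_fn` would write the dictionary to disk; both are left abstract here.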
Special care must be taken with the PyTorch version, as the argument names of the ZeRO methods have changed between recent versions.
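One way to cope with the renamed argument is to inspect the method signature at call time instead of comparing version strings. The sketch below assumes the target rank is named either `to` (newer releases) or `recipient_rank` (older releases); `consolidate_on` is a hypothetical helper name, and the positional fallback is a guess for any other naming.

```python
import inspect
from typing import Any


def consolidate_on(optimizer: Any, rank: int = 0) -> None:
    """Call consolidate_state_dict with whichever keyword the method accepts."""
    # The signature of a bound method does not include `self`.
    params = inspect.signature(optimizer.consolidate_state_dict).parameters
    if "to" in params:
        # Assumed newer argument name.
        optimizer.consolidate_state_dict(to=rank)
    elif "recipient_rank" in params:
        # Assumed older argument name.
        optimizer.consolidate_state_dict(recipient_rank=rank)
    else:
        # Unknown naming: fall back to passing the rank positionally.
        optimizer.consolidate_state_dict(rank)
```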
Describe alternatives you've considered
The only alternative is for users to write their own handlers for this, which separates the consolidation from the checkpoint-saving logic. Such a handler can be written as:
from ignite.engine import Engine, Events
from torch.distributed.optim import ZeroRedundancyOptimizer


class ConsolidateZeROHandler:
    """
    Handler that consolidates the Zero Redundancy Optimizer prior to checkpoint saving.

    Args:
        zero_optimizer (ZeroRedundancyOptimizer): The optimizer to be consolidated.
        call_every (int): Consolidate every `call_every` epochs or iterations.
        recipient_rank (int): The rank on which the consolidation will happen. Defaults to 0.
        epoch_level (bool): Call every N epochs or every N iterations. `True` is epoch level, `False` is iteration
            level. Defaults to True.
    """

    def __init__(
        self,
        zero_optimizer: ZeroRedundancyOptimizer,
        call_every: int,
        recipient_rank: int = 0,
        epoch_level: bool = True,
    ):
        self.zero_optimizer = zero_optimizer
        self.recipient_rank = recipient_rank
        self.epoch_level = epoch_level
        self.call_every = call_every

    def __call__(self, engine: Engine) -> None:
        # Must run on all ranks; the keyword `to` depends on the PyTorch version.
        self.zero_optimizer.consolidate_state_dict(to=self.recipient_rank)

    def attach(self, engine: Engine) -> None:
        # Register on the matching event so consolidation runs every `call_every` steps.
        if self.epoch_level:
            engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.call_every), self)
        else:
            engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.call_every), self)