Consolidate ZeRO state before checkpoint saving #2623

@danieltudosiu

Is your feature request related to a problem? Please describe.
When a checkpoint is saved, there should be a check for whether the object that state_dict() is called on is a ZeroRedundancyOptimizer instance.

Describe the solution you'd like
Before state_dict() is called, consolidate_state_dict() should be called. This call must be issued on all ranks, and every rank must point to the same consolidating rank.
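The ordering above can be sketched roughly as follows. This is only an illustration under stated assumptions: `save_consolidated`, `rank`, and `save_fn` are hypothetical names, not part of ignite or PyTorch, and the optimizer is treated as any object exposing `consolidate_state_dict(to=...)` and `state_dict()`.

```python
from typing import Any, Callable


def save_consolidated(
    optimizer: Any,
    rank: int,
    save_fn: Callable[[dict], None],
    target_rank: int = 0,
) -> bool:
    """Consolidate on every rank, then save only on ``target_rank``.

    Returns True if this rank wrote the checkpoint, False otherwise.
    ``save_fn`` is a hypothetical callback that persists the state dict.
    """
    # Collective step: every rank calls this with the same target rank.
    optimizer.consolidate_state_dict(to=target_rank)

    # Only the consolidating rank reads the state and saves it.
    if rank != target_rank:
        return False

    save_fn(optimizer.state_dict())
    return True
```

In a real run, `rank` would come from the distributed backend, and `save_fn` would be whatever the Checkpoint handler uses to write to disk.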

There are two possible designs: either instantiate the Checkpoint handler on all ranks and have it save the checkpoint only on the designated rank, or create a separate handler that runs before the Checkpoint handler to perform the consolidation.

Care must be taken with the PyTorch version, as the names of the ZeRO method arguments have changed between recent versions.
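One way to cope with the renamed argument, sketched here as an assumption rather than a settled design, is to inspect the method signature at call time. The helper name `consolidate` is hypothetical, and `recipient_rank` is assumed to be the older keyword name; it should be checked against the PyTorch versions that need to be supported.

```python
import inspect
from typing import Any


def consolidate(optimizer: Any, rank: int = 0) -> None:
    """Call consolidate_state_dict with whichever keyword the optimizer accepts."""
    params = inspect.signature(optimizer.consolidate_state_dict).parameters
    if "to" in params:
        # Keyword used in the handler shown in this issue.
        optimizer.consolidate_state_dict(to=rank)
    elif "recipient_rank" in params:
        # Assumed keyword name in older releases.
        optimizer.consolidate_state_dict(recipient_rank=rank)
    else:
        # Unknown signature: fall back to passing the rank positionally.
        optimizer.consolidate_state_dict(rank)
```

Checking the signature avoids parsing version strings, at the cost of relying on the parameter names themselves.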

Describe alternatives you've considered
The only alternative is for users to write their own handler for this, which separates the consolidation from the checkpoint-saving logic. Such a handler can be written as:

from ignite.engine import Engine, Events
from torch.distributed.optim import ZeroRedundancyOptimizer


class ConsolidateZeROHandler:
    """
    Handler that consolidates the ZeroRedundancyOptimizer state before a checkpoint is saved.

    Args:
        zero_optimizer (ZeroRedundancyOptimizer): The optimizer to be consolidated.
        call_every (int): Run the consolidation every N epochs or every N iterations.
        recipient_rank (int): The rank on which the consolidation will happen. Defaults to 0.
        epoch_level (bool): `True` counts `call_every` in epochs, `False` in iterations.
            Defaults to True.
    """

    def __init__(
        self,
        zero_optimizer: ZeroRedundancyOptimizer,
        call_every: int,
        recipient_rank: int = 0,
        epoch_level: bool = True,
    ):
        self.zero_optimizer = zero_optimizer
        self.recipient_rank = recipient_rank
        self.epoch_level = epoch_level
        self.call_every = call_every

    def __call__(self, engine: Engine) -> None:
        # Must run on every rank, with the same recipient rank everywhere.
        self.zero_optimizer.consolidate_state_dict(to=self.recipient_rank)

    def attach(self, engine: Engine) -> None:
        # Attach before the Checkpoint handler so that consolidation runs first.
        if self.epoch_level:
            engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.call_every), self)
        else:
            engine.add_event_handler(Events.ITERATION_COMPLETED(every=self.call_every), self)
