
Reuse Valid Accumulated Gradients Upon Failure #192

Open
@WarrenZhu050413

Description:

The current failure model in TorchFT handles node failures by zeroing out all accumulated gradients and recalculating them in the subsequent forward/backward pass. This approach, while ensuring correctness, leads to a loss of computational work, as gradients accumulated by surviving replica groups before the failure are discarded, even though they remain valid.

Current Behavior:

Upon failure of a replica group, all gradients are zeroed. The full batch (or adjusted micro-batches for the remaining replicas) is reprocessed, and gradients are recalculated from scratch.

Problem:

This recalculation is inefficient, especially in environments with high failure rates. If failures occur frequently, a significant amount of computation is wasted recomputing gradients that were valid and available on surviving replicas.

Example Scenario:

Consider a setup with 3 replica groups and a global batch size of 24. Each replica group processes a local batch of 8 samples.

  1. Replica group 0 (RG0) processes 8 samples.
  2. Replica group 1 (RG1) processes 8 samples.
  3. Replica group 2 (RG2) processes 8 samples.
  4. RG0 fails before the all_reduce operation.

Currently, the gradients accumulated by RG1 (for 8 samples) and RG2 (for 8 samples) would be zeroed. In the recovery step, the full 24-sample batch (or a re-sharded subset) would need to be reprocessed entirely by the surviving replicas.
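To make the cost concrete, here is a simplified accounting of the current recovery path using the numbers from the scenario above. This is plain Python arithmetic, not TorchFT APIs; all variable names are illustrative.

```python
# Scenario: 3 replica groups, local batch of 8, global batch of 24.
replica_groups = 3
local_batch = 8
global_batch = replica_groups * local_batch  # 24

# RG0 fails before the all_reduce: every group's accumulated
# gradients are zeroed, so no prior work is preserved.
surviving_groups = replica_groups - 1
preserved_samples = 0  # RG1's and RG2's 16 samples of work are discarded

# The full global batch is reprocessed by the survivors.
reprocessed = global_batch
per_survivor = reprocessed // surviving_groups  # 12 samples each

# Wasted work = samples that were recomputed even though their
# gradients were already valid on surviving groups.
wasted = reprocessed - local_batch - preserved_samples  # 16 samples
```

Only RG0's 8 samples genuinely need recomputation; the other 16 samples of reprocessing are pure overhead, and this overhead recurs on every failure.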

Proposed Enhancement:

Instead of zeroing all gradients, TorchFT should identify and preserve the valid gradients accumulated by the surviving replica groups. In the example above, upon RG0's failure, the gradients from RG1 and RG2 (representing 16 samples) should be retained. This could be built on top of #186.

For the recovery iteration, the system would then only need to process the data originally handled by the failed RG0 (8 samples). These samples could be redistributed among the surviving RG1 and RG2 (e.g., 4 samples each). The newly computed gradients for these 8 samples would then be combined with the preserved gradients from the previous partial computation on RG1 and RG2 before the all_reduce operation.
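The proposed recovery can be sketched numerically. The following is a minimal simulation in plain Python (scalar "gradients", one per sample; no TorchFT APIs), showing that combining the preserved accumulators with freshly computed gradients for the failed group's samples reproduces the failure-free result exactly.

```python
def mean_grad(per_sample_grads):
    """Average per-sample gradients, as an optimizer step would see them."""
    return sum(per_sample_grads) / len(per_sample_grads)

# Global batch of 24 samples split across 3 replica groups of 8.
samples = [float(i) for i in range(24)]
rg0, rg1, rg2 = samples[0:8], samples[8:16], samples[16:24]

# Before the failure, each surviving group holds the sum of its
# local per-sample gradients. These remain valid and are preserved.
sum_rg1 = sum(rg1)
sum_rg2 = sum(rg2)

# RG0 fails; its 8 samples are redistributed to the survivors (4 each).
extra_rg1, extra_rg2 = rg0[0:4], rg0[4:8]

# Each survivor adds the newly computed gradients to its preserved
# accumulator; the (simulated) all_reduce then sums across groups
# and normalizes by the global batch size.
combined = (sum_rg1 + sum(extra_rg1)) + (sum_rg2 + sum(extra_rg2))
recovered_grad = combined / len(samples)

# Identical to the gradient a failure-free step would have produced.
full_grad = mean_grad(samples)
```

The key invariant is that gradient accumulation is a sum over per-sample contributions, so the order and placement of the partial sums does not matter as long as every sample is counted exactly once and the normalization uses the global batch size.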
