Description:
The current failure model in TorchFT handles node failures by zeroing out all accumulated gradients and recalculating them in the subsequent forward/backward pass. This approach, while ensuring correctness, leads to a loss of computational work, as gradients accumulated by surviving replica groups before the failure are discarded, even though they remain valid.
Current Behavior:
Upon failure of a replica group, all gradients are zeroed. The full batch (or adjusted micro-batches for the remaining replicas) is reprocessed, and gradients are recalculated from scratch.
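A rough sketch of this recovery path, using hypothetical helper names and a plain PyTorch-style loop rather than TorchFT's actual code:

```python
def recover_by_recompute(model, loss_fn, micro_batches):
    """Hypothetical illustration of the current behavior: after a failure,
    every surviving replica discards its accumulated gradients and recomputes
    the whole (possibly re-sharded) batch from scratch."""
    model.zero_grad()                          # valid gradients are thrown away
    for inputs, targets in micro_batches:      # full forward/backward recompute
        loss = loss_fn(model(inputs), targets)
        loss.backward()
    # the recomputed gradients are then averaged across replica groups via all_reduce
```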
Problem:
This recalculation is inefficient, especially in environments with high failure rates: a significant amount of computation is wasted recomputing gradients that were already valid and available on the surviving replicas.
Example Scenario:
Consider a setup with 3 replica groups and a global batch size of 24. Each replica group processes a local batch of 8 samples.
- Replica group 0 (RG0) processes 8 samples.
- Replica group 1 (RG1) processes 8 samples.
- Replica group 2 (RG2) processes 8 samples.
- RG0 fails before the `all_reduce` operation.
Currently, the gradients accumulated by RG1 (for 8 samples) and RG2 (for 8 samples) would be zeroed. In the recovery step, the full batch of 24 samples (or a subset, if data is re-sharded) would need to be reprocessed entirely by the surviving replicas.
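In this example, the gradients for 16 of the 24 samples (two thirds of the backward work for the step) are discarded even though they are still valid.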
Proposed Enhancement:
Instead of zeroing all gradients, TorchFT should identify and preserve the valid gradients accumulated by the surviving replica groups. In the example above, upon RG0's failure, the gradients from RG1 and RG2 (representing 16 samples) should be retained. This could be built on top of #186.
For the recovery iteration, the system would then only need to process the data originally handled by the failed RG0 (8 samples). These samples could be redistributed among the surviving RG1 and RG2 (e.g., 4 samples each). The newly computed gradients for these 8 samples would then be combined with the preserved gradients from the previous partial computation on RG1 and RG2 before the `all_reduce` operation.
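A minimal sketch of this proposed flow, written from the point of view of a surviving replica such as RG1 (hypothetical helper names and re-sharding logic, not an actual TorchFT API):

```python
def recover_preserving_gradients(model, loss_fn, failed_rg_micro_batches):
    """Hypothetical illustration of the proposed behavior: the gradients already
    accumulated on this surviving replica (for its original 8 samples) are kept,
    and only this replica's share of the failed RG0's data is reprocessed."""
    # Note: model.zero_grad() is intentionally NOT called; .grad already holds
    # valid gradients for the samples processed before the failure.
    for inputs, targets in failed_rg_micro_batches:   # e.g. 4 of RG0's 8 samples
        loss = loss_fn(model(inputs), targets)
        loss.backward()                               # accumulates into existing .grad
    # The combined gradients (original 8 samples + redistributed 4 samples) then
    # participate in the usual all_reduce across the surviving replica groups.
```

One detail the sketch glosses over is loss scaling: if the loss is averaged per micro-batch, the preserved and newly computed gradients would need consistent weighting so that every sample contributes equally to the final averaged gradient before the `all_reduce`.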