enable backward pass computation and communication overlap by prefetching all gather (pytorch#70235)

Summary:
Pull Request resolved: pytorch#70235

This addresses comments in pytorch#69282: a few corner cases for prefetching full parameters in the post-backward hook have been fixed. After benchmarking, prefetching full parameters in the pre-backward hook has the best and most stable performance, but at the cost of increased memory; prefetching full parameters in the post-backward hook did not show the expected performance gains and also failed in a few corner cases (now fixed), although it does not increase memory. The main issue is that the post-backward hook firing order is not consistent with the reverse of the forward computation order, so an incorrectly prefetched all gather can delay the all gather that is actually needed in the single NCCL stream and stall some layers' computation. These two approaches are therefore exposed as two configurable experimental algorithms for now.

Prefetching full parameters in the pre-backward hook: past traces show that all gather ops are not triggered until the current layer's backward pass starts to compute, and for some models a previous layer's reduce scatter is scheduled before the next layer's all gather. Since all gather and reduce scatter share the same NCCL stream, this can leave the backward pass with no overlap between communication and computation. To explicitly schedule the next layer's all gather while the previous layer's backward computation is running, we prefetch the next layer's all gather of full parameters. This helps in two ways: 1) both all gather and reduce scatter are deterministically overlapped with computation, and 2) only one layer's all gather of full parameters is prefetched at a time, to avoid increasing memory too much.

The implementation borrows the idea from facebookresearch/fairscale#865, where the forward graph order is recorded during the forward pass. In the backward pass, this PR prefetches the all gather of full parameters in the current layer's pre-backward hook, instead of in the current layer's post-backward hook as in facebookresearch/fairscale#865. It also makes sure the all gather streams are synced properly.

Experiments showed a 10% memory increase and a 20% latency speedup for a 1GB RoBERTa model in a slow network environment.

Test Plan: unit tests

Reviewed By: rohan-varma

Differential Revision: D33252795

fbshipit-source-id: 4e2f47225ba223e7429b0dcaa89df3634bb70050
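A minimal sketch of the pre-backward prefetching idea described above, for illustration only: it runs in a single process, the `PrefetchingWrapper` class and `_unshard` helper are hypothetical names (not FSDP's API), and a print stands in for the real all gather issued on a separate NCCL stream. What it demonstrates is recording the forward execution order and then, in each layer's pre-backward hook, kicking off the unshard of the layer that backward will need next.

```python
# Hypothetical sketch, not FSDP code: single process, print() stands in
# for the all gather of full parameters on a side NCCL stream.
import torch
import torch.nn as nn

class PrefetchingWrapper(nn.Module):
    """Records forward execution order and, in each layer's pre-backward
    hook, prefetches the unshard of the layer backward will need next
    (the previous layer in forward order)."""

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.forward_order = []  # execution order recorded in forward

    def _unshard(self, idx):
        # Stand-in for issuing the all gather of layer idx's full
        # parameters on a separate communication stream.
        print(f"prefetch all gather for layer {idx}")

    def forward(self, x):
        self.forward_order.clear()
        for idx, layer in enumerate(self.layers):
            self.forward_order.append(idx)
            x = layer(x)
            pos = len(self.forward_order) - 1

            # Tensor hook on this layer's output: it fires when the
            # gradient w.r.t. the output is ready, i.e. just before this
            # layer's backward computation starts (a "pre-backward hook").
            def pre_backward(grad, pos=pos):
                # Backward visits layers in reverse forward order, so the
                # next layer backward needs is forward_order[pos - 1];
                # prefetching it here lets its all gather overlap with
                # this layer's backward compute.
                if pos > 0:
                    self._unshard(self.forward_order[pos - 1])
                return grad

            if x.requires_grad:
                x.register_hook(pre_backward)
        return x

model = PrefetchingWrapper([nn.Linear(8, 8) for _ in range(3)])
out = model(torch.randn(4, 8))
out.sum().backward()  # hooks fire in reverse order, prefetching one layer ahead
```

In the real implementation the prefetched all gather runs on its own stream, and (as the summary notes) the compute stream must be synced with the all gather stream before the prefetched layer's backward consumes the full parameters.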