Fix overlap communication of ZeRO stage 1 and 2 #5606
@microsoft-github-policy-service agree company="Huawei"
Hi @tjruwase @GuanhuaWang, could you please help review this PR?
Hi @penn513, thanks for the nice figure and PR. I think your fix of making the compute stream wait back on the reduce stream makes sense, especially when compute is shorter. To keep the PR concise, could you remove your modification to the NPU fused Adam from the current PR and open a new PR for it? (Mainly because it is unrelated to this PR's title, "fix overlap communication...".)
Co-authored-by: CurryRice233 <nmeia@qq.com>
Thanks for your reply. It's been updated.
`deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer.average_tensor` only makes the reduction stream wait for the default stream. This is fine when the computation time is longer than the communication time, but when the communication time is longer, the ipg_buffer may be overwritten before the communication has completed.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/950cbf8a-f439-4cf9-a364-dcdfd47f46a0)

The easiest way to fix this bug is to also make the default stream wait for the reduction stream at the **same point**. For example, at point 1, the `reduction stream` needs to wait for '2', so we add a wait_stream that makes the `reduction stream` wait for the `default stream`. Likewise, the `default stream` needs to wait for 'A', so we add a wait_stream that makes the `default stream` wait for the `reduction stream` before 'B'.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/588a9469-d3f9-4c39-976d-3ae0502cf1d1)

Compared with the modification in deepspeedai#5523, wait_stream does not cause host synchronization.
Compared with the modification in deepspeedai#5545, this change is simpler and the logic is the same: each stream waits only for what it needs to wait for.

---

With this modification, the losses of Qwen-1.5 with and without overlap_comm are identical.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/4d48d54e-e55b-4230-8b99-93549910a43f)

---

By contrast, without it there is an obvious gap when the sequence length is small, that is, when the computation time is short.

![image](https://github.com/microsoft/DeepSpeed/assets/35059704/c80af498-3358-4e36-9b13-8f266551d51d)

Co-authored-by: gp513 <guopeng34@huawei.com>
Co-authored-by: CurryRice233 <nmeia@qq.com>
Co-authored-by: Joe Mayer <114769929+jomayeri@users.noreply.github.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
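
The ordering argument in the description can be checked with a small timeline model. This is only a sketch in pure Python, not accelerator code: `SimStream` and `count_buffer_hazards` are hypothetical names, and the two alternating buffers and the fixed per-bucket durations are assumptions made for illustration. It models `wait_stream` as "later work on this stream starts only after everything already enqueued on the other stream has finished", and counts how often a buffer is refilled while its previous reduction is still running.

```python
from dataclasses import dataclass


@dataclass
class SimStream:
    """Toy stand-in for a device stream: ops run back to back on a timeline."""
    name: str
    cursor: float = 0.0  # earliest time at which the next enqueued op can start

    def enqueue(self, duration):
        """Append an op of the given duration; return its (start, end) times."""
        start, end = self.cursor, self.cursor + duration
        self.cursor = end
        return start, end

    def wait_stream(self, other):
        # Work enqueued on `self` after this call starts only once all work
        # already enqueued on `other` is done. The host is never blocked.
        self.cursor = max(self.cursor, other.cursor)


def count_buffer_hazards(compute_time, comm_time, steps, two_way_wait):
    """Count buffer refills that start before that buffer's reduce has ended."""
    default = SimStream("default")
    reduction = SimStream("reduction")
    reduce_end = [0.0, 0.0]  # assumption: two buffers used alternately
    hazards = 0
    for i in range(steps):
        buf = i % 2
        fill_start, _ = default.enqueue(compute_time)  # fill the buffer
        if fill_start < reduce_end[buf]:
            hazards += 1  # buffer overwritten while still being reduced
        # Existing wait: reduce only after the buffer has been filled.
        reduction.wait_stream(default)
        if two_way_wait:
            # Added wait at the same point: the default stream does not run
            # ahead of reductions that are still in flight.
            default.wait_stream(reduction)
        _, reduce_end[buf] = reduction.enqueue(comm_time)  # reduce the buffer
    return hazards


if __name__ == "__main__":
    # Communication slower than computation: the one-way wait is not enough.
    print("one-way:", count_buffer_hazards(1.0, 3.0, steps=4, two_way_wait=False))
    print("two-way:", count_buffer_hazards(1.0, 3.0, steps=4, two_way_wait=True))
```

In this model the one-way wait is hazard-free only when each fill takes at least as long as each reduce, which matches the observation that the problem shows up with short computation times. With the two-way wait, the fill of the next bucket still overlaps the reduce of the current one, so some overlap is retained.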