Fix FT allreduce bug #1170
Conversation
LGTM
```diff
@@ -59,7 +58,7 @@ def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:

     def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
         def all_reduce_hook(output):
-            dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
+            self.replicate_pg.allreduce(output, opts=ReduceOp.AVG)
```
Can we add a comment discussing why we need to use this call instead of the original one? This is less intuitive.
Sure, I will add a comment. This is due to these changes (pytorch/pytorch@35c45a4#diff-61109d1cb2a0bd13fc51d678a82666295289da1ec0a1a694e73d9e8c28f51bdcR2885), so the regular `dist.all_reduce()` is no longer compatible.
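For reference, a minimal sketch of the hook with the kind of inline comment being asked for (this assumes the surrounding class from the diff above, where `self.replicate_pg` is the replica process group):

```python
def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
    def all_reduce_hook(output):
        # After pytorch/pytorch@35c45a4, dist.all_reduce() no longer returns
        # a work object on this path, so downstream code that calls .wait()
        # on the result fails with "'NoneType' object has no attribute 'wait'".
        # Calling allreduce() on the process group directly returns a work
        # object, so use it instead of dist.all_reduce().
        self.replicate_pg.allreduce(output, opts=ReduceOp.AVG)
```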
@H-Huang can we fix this on the torchft side? We should just be able to ignore the case when there's no returned work object, IIUC.
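As a rough, hypothetical sketch of that idea (not code from this PR or from torchft internals): wherever the collective's returned work object is waited on, skip the wait when nothing was returned:

```python
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

def averaging_hook(output: torch.Tensor, replicate_pg: dist.ProcessGroup) -> None:
    # Hypothetical guard: tolerate a collective that returns no work object
    # instead of unconditionally calling .wait() on the result.
    work = dist.all_reduce(output, group=replicate_pg, op=ReduceOp.AVG)
    if work is not None:
        work.wait()
```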
I think it's better to fix this on the torchft side.
I was seeing this log / commit failure when using FT:
`should_commit=False enough_replicas=True, errored='NoneType' object has no attribute 'wait'`
It was coming from an FSDP hook which was calling `dist.all_reduce()`. The regular c10d collective now returns `None`, which causes the error above. Instead we should be using `self.replicate_pg.allreduce(output, opts=ReduceOp.AVG)`. I think we could also achieve the same thing with `manager.allreduce(output)`, so I'm not sure if the `ManagedProcessGroup` is needed. Would appreciate any thoughts @fegin @d4l3k