Skip to content

Fix FT allreduce bug #1170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed

Fix FT allreduce bug #1170

wants to merge 1 commit into from

Conversation

H-Huang
Copy link
Member

@H-Huang H-Huang commented May 6, 2025

I was seeing this log / commit failure when using FT: should_commit=False enough_replicas=True, errored='NoneType' object has no attribute 'wait’

Pasted Graphic

It was coming from an FSDP hook which was calling dist.allreduce() the regular c10d collective now returns None which causes the error above. Instead we should be using self.replicate_pg.allreduce(output, opts=ReduceOp.AVG). I think we could also achieve the same thing with manager.allreduce(output), so I'm not sure if the ManagedProcessGroup is needed. Would appreciate any thoughts @fegin @d4l3k

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 6, 2025
@H-Huang H-Huang changed the title Fix allreduce bug Fix FT allreduce bug May 6, 2025
@H-Huang H-Huang marked this pull request as ready for review May 6, 2025 16:10
@H-Huang H-Huang requested review from fegin and d4l3k May 6, 2025 16:10
Copy link
Contributor

@fegin fegin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -59,7 +58,7 @@ def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:

def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
def all_reduce_hook(output):
dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
self.replicate_pg.allreduce(output, opts=ReduceOp.AVG)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment discussing why we need to use this call instead of the original one? This is less intuitive.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure I will add a comment, this is due to these changes (pytorch/pytorch@35c45a4#diff-61109d1cb2a0bd13fc51d678a82666295289da1ec0a1a694e73d9e8c28f51bdcR2885, so the regular dist.all_reduce() is no longer compatible

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@H-Huang can we fix this in the torchft side? We should just be able to ignore the case when there's no return work object iiuc

Copy link
Member

@d4l3k d4l3k left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to fix this in torchft side

@d4l3k
Copy link
Member

d4l3k commented May 6, 2025

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants