[Core][Distributed] refactor pynccl to hold multiple communicators #4591
Conversation
@youkaichao Thanks a lot for the PR! The refactoring makes sense.
Left some comments on code style and possible errors in the PR. Please check my review.
```python
# A small all_reduce for warmup.
data = torch.zeros(1)
if torch.cuda.is_available():
    data = data.to(device=f"cuda:{local_rank}")
torch.distributed.all_reduce(data)
```
I feel warmup should not be a part of this method?
Which method do you think is better?
vllm/distributed/parallel_state.py (outdated)
```python
group = torch.distributed.new_group(ranks, backend=backend)
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if rank in ranks:
    _TP_DEVICE_GROUP = group
    _TP_CPU_GROUP = cpu_group

from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
```
Again, why do we need this lazy import?
The lazy import is required here to avoid a circular import: vllm.distributed.device_communicators.pynccl will try to import vllm/distributed/parallel_state.py.
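To make the cycle concrete, here is a minimal two-file sketch (stand-in names, not the actual vLLM code); only the import structure mirrors the situation described above, and the function and variable names are hypothetical:

```python
# --- parallel_state_stub.py (stands in for vllm/distributed/parallel_state.py) ---
_TP_DEVICE_GROUP = None

def get_tp_group():
    return _TP_DEVICE_GROUP

def init_tp_communicator(group):
    # Lazy import: a top-level `from pynccl_stub import NCCLCommunicator`
    # would be circular, because pynccl_stub imports this module at its
    # own top level (see below).
    from pynccl_stub import NCCLCommunicator
    return NCCLCommunicator(group=group)

# --- pynccl_stub.py (stands in for vllm/distributed/device_communicators/pynccl.py) ---
from parallel_state_stub import get_tp_group  # module-level dependency

class NCCLCommunicator:
    def __init__(self, group=None):
        # The default falls back to the TP group, which is why the
        # module-level import above exists in the first place.
        self.group = group if group is not None else get_tp_group()
```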
If that's the case, I think it means it's a bad design tbh. We should use lazy imports only to avoid unnecessary imports, not to avoid circular imports. Otherwise, the code will be too complicated.
Yeah, we can have a better design. The reason we have this circular import is that we tried very hard to figure out the default argument for the group (which requires the import from vllm/distributed/parallel_state.py). We can remove this, but it might break some old code. (I can do it if you think it is good.)
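A hedged sketch of the alternative being discussed here: drop the implicit default so pynccl no longer needs to import parallel_state at all (illustrative only, not the final API):

```python
# pynccl without a dependency on parallel_state: the caller must pass the
# process group explicitly instead of relying on an implicit TP-group default.
class NCCLCommunicator:
    def __init__(self, group, device=None):
        # No default pulled from parallel_state, hence no module-level import
        # of parallel_state and no circular import. Callers that relied on
        # the old implicit default would have to be updated, which is the
        # backward-compatibility concern mentioned above.
        self.group = group
        self.device = device
```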
@WoosukKwon one change after our discussion: c1b1cdb
LGTM. Thanks for addressing my review! Looking forward to the planned refactoring!
[Core][Distributed] refactor pynccl to hold multiple communicators (vllm-project#4591)
Currently, pynccl is bound to a module-level instance, and we use it only for the tensor parallel group. After this refactor, we can create as many pynccl communicator instances as we want, e.g. a new pynccl communicator for the pipeline parallel group. This is part of an ongoing effort to support pipeline parallelism (#4412).
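A rough usage sketch of what the refactor enables; the constructor signature and group layout here are assumptions for illustration, not the exact vLLM API, and the snippet presumes torch.distributed is already initialized across 8 ranks:

```python
import torch
import torch.distributed as dist
from vllm.distributed.device_communicators.pynccl import NCCLCommunicator

# Assumed: dist.init_process_group(...) has already been called on every rank.
local_rank = dist.get_rank() % torch.cuda.device_count()

# Independent process groups for tensor parallelism and pipeline parallelism.
tp_group = dist.new_group(ranks=[0, 1, 2, 3], backend="gloo")
pp_group = dist.new_group(ranks=[0, 4], backend="gloo")

# After the refactor, each group can own its own pynccl communicator instead
# of sharing a single module-level instance.
tp_comm = NCCLCommunicator(group=tp_group, device=local_rank)
pp_comm = NCCLCommunicator(group=pp_group, device=local_rank)
```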