
[Core][Distributed] refactor pynccl to hold multiple communicators #4591

Merged
merged 44 commits into vllm-project:main on May 10, 2024

Conversation

youkaichao
Member

@youkaichao youkaichao commented May 3, 2024

Currently, pynccl is bound to a module-level instance, and we use it only for the tensor parallel group.

After this refactor, we can create as many pynccl communicator instances as we want, e.g. a new pynccl communicator for the pipeline parallel group.

This is an ongoing effort to support pipeline parallel #4412 .
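To illustrate the idea of the refactor, here is a minimal sketch of per-group communicator instances replacing a single module-level one. The class and attribute names below are illustrative only, not vLLM's actual API:

```python
class PyNcclCommunicatorSketch:
    """Illustrative stand-in: one communicator object per process group."""

    def __init__(self, group_name: str, ranks: list[int]):
        self.group_name = group_name
        self.ranks = ranks

# After the refactor, each parallel group can hold its own instance,
# e.g. one for tensor parallel and one for pipeline parallel.
tp_comm = PyNcclCommunicatorSketch("tensor_parallel", ranks=[0, 1, 2, 3])
pp_comm = PyNcclCommunicatorSketch("pipeline_parallel", ranks=[0, 4])
```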

Collaborator

@WoosukKwon WoosukKwon left a comment


@youkaichao Thanks a lot for the PR! The refactoring makes sense.

I left some comments on code style and possible errors. Please check my review.

Review threads (outdated, resolved) on:
- tests/distributed/test_pynccl.py (4 threads)
- vllm/distributed/communication_op.py (2 threads)
- vllm/distributed/parallel_state.py (1 thread)
Comment on lines +93 to +97:

    # A small all_reduce for warmup.
    data = torch.zeros(1)
    if torch.cuda.is_available():
        data = data.to(device=f"cuda:{local_rank}")
    torch.distributed.all_reduce(data)
Collaborator


I feel warmup should not be a part of this method?
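One way to address this comment would be to factor the warmup into its own helper, separate from group-creation logic. The sketch below is not the PR's actual change; the helper name and the single-process gloo setup (included only to make the sketch runnable) are assumptions:

```python
import os
import tempfile

import torch
import torch.distributed as dist

def warmup_all_reduce(local_rank: int) -> None:
    # A tiny all_reduce to eagerly initialize the communication backend,
    # kept out of the method that builds the process groups.
    data = torch.zeros(1)
    if torch.cuda.is_available():
        data = data.to(device=f"cuda:{local_rank}")
    dist.all_reduce(data)

# Single-process gloo group, only so this sketch runs standalone.
store_file = os.path.join(tempfile.mkdtemp(), "store")
dist.init_process_group(
    backend="gloo", init_method=f"file://{store_file}", rank=0, world_size=1
)
warmup_all_reduce(local_rank=0)
```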

Member Author


which method do you think is better?

    group = torch.distributed.new_group(ranks, backend=backend)
    cpu_group = torch.distributed.new_group(ranks, backend="gloo")
    if rank in ranks:
        _TP_DEVICE_GROUP = group
        _TP_CPU_GROUP = cpu_group

    from vllm.distributed.device_communicators.pynccl import NCCLCommunicator
Collaborator


Again, why do we need this lazy import?

Member Author


A lazy import is required here to avoid a circular import: vllm.distributed.device_communicators.pynccl itself imports vllm/distributed/parallel_state.py.
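For readers unfamiliar with the pattern under discussion: deferring one of the two imports into the function that needs it breaks the cycle, because by the time the function runs both modules are fully initialized. The self-contained demo below simulates the situation with two throwaway modules written to a temp directory; all module and class names here are hypothetical stand-ins for parallel_state.py and pynccl:

```python
import os
import sys
import tempfile
import textwrap

tmp = tempfile.mkdtemp()
sys.path.insert(0, tmp)

# "parallel_state_demo" imports "pynccl_demo" lazily, inside a function.
with open(os.path.join(tmp, "parallel_state_demo.py"), "w") as f:
    f.write(textwrap.dedent("""
        TP_GROUP = "tp-group"

        def init_communicator():
            # Lazy import: deferring this to call time avoids the
            # circular import at module load.
            from pynccl_demo import Communicator
            return Communicator(TP_GROUP)
    """))

# "pynccl_demo" imports the other module at top level (the reverse edge).
with open(os.path.join(tmp, "pynccl_demo.py"), "w") as f:
    f.write(textwrap.dedent("""
        import parallel_state_demo  # top-level import of the other module

        class Communicator:
            def __init__(self, group):
                self.group = group
    """))

import parallel_state_demo

# The cycle resolves because one of its two edges is deferred.
comm = parallel_state_demo.init_communicator()
```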

Collaborator

@WoosukKwon WoosukKwon May 9, 2024


If that's the case, I think it's a bad design, tbh. We should use lazy imports only to avoid unnecessary imports, not to avoid circular imports. Otherwise, the code becomes too complicated.

Member Author


Yeah, we can have a better design. The reason we have this circular import is that we tried very hard to figure out the default argument for the group (which requires the import from vllm/distributed/parallel_state.py). We can remove this, but it might break some old code. (I can do it if you think that's good.)

@WoosukKwon WoosukKwon removed their assignment May 9, 2024
@youkaichao
Member Author

@WoosukKwon one change after our discussion:

c1b1cdb changes `with pynccl_comm.enable()` to `with pynccl_comm.change_state(enable=True)`. I think this makes more sense.
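A minimal sketch of what a `change_state`-style context manager can look like; the class, attribute, and parameter names below are illustrative, and vLLM's actual communicator API may differ:

```python
from contextlib import contextmanager

class PyNcclCommunicatorSketch:
    """Illustrative stand-in for the PR's communicator."""

    def __init__(self):
        self.disabled = True  # start disabled; enable only where wanted

    @contextmanager
    def change_state(self, enable: bool):
        # Save the old state, apply the requested one, restore on exit.
        old_disabled = self.disabled
        self.disabled = not enable
        try:
            yield self
        finally:
            self.disabled = old_disabled

comm = PyNcclCommunicatorSketch()
with comm.change_state(enable=True):
    inside = comm.disabled  # False: the communicator is active here
after = comm.disabled       # True: the previous state is restored
```

Compared with a one-way `enable()`, passing the desired state explicitly makes the scope of the change obvious and guarantees restoration on exit, even on exceptions.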

@youkaichao youkaichao removed the request for review from zhuohan123 May 9, 2024 08:04
Collaborator

@WoosukKwon WoosukKwon left a comment


LGTM. Thanks for addressing my review! Looking forward to the planned refactoring!

@youkaichao youkaichao merged commit 208b71b into vllm-project:main May 10, 2024
55 checks passed
@youkaichao youkaichao deleted the bind_pynccl_to_group branch May 10, 2024 02:48
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
[Core][Distributed] refactor pynccl to hold multiple communicators (vllm-project#4591)
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
[Core][Distributed] refactor pynccl to hold multiple communicators (vllm-project#4591)
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
[Core][Distributed] refactor pynccl to hold multiple communicators (vllm-project#4591)