Use NCCL instead of ray for control-plane communication to remove serialization overhead #2221

Merged
@zhuohan123 merged 35 commits into main from remove-serialization-overhead on Jan 3, 2024

Conversation

@zhuohan123 (Member) commented Dec 20, 2023

This PR modifies vLLM to use NCCL instead of ray for control-plane communication. The architectural change of vLLM with this PR can be summarized in the following figure:

[Figure: vLLM architecture before and after this PR]

Before this change, vLLM has one driver process (which does not itself manage a GPU) and N worker processes, each of which is a Ray actor that manages one GPU. After this change, we move one worker into the driver process (the driver worker) and keep N-1 Ray workers. All control messages are broadcast from the driver worker to the remaining workers with NCCL. This avoids the high serialization cost of Ray communication.
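Control-plane messages are Python metadata (scheduler outputs, sampling parameters, and so on) rather than tensors, so they must be serialized before being broadcast. Below is a minimal sketch of how such metadata can be shipped over torch.distributed instead of per-call Ray remote invocations; the function name and setup are illustrative assumptions, not the exact vLLM code.

# Illustrative sketch (not the exact vLLM implementation) of broadcasting
# control metadata from the driver worker (rank 0) to all other workers.
import torch.distributed as dist

def broadcast_control_metadata(metadata=None, src=0):
    """Rank `src` sends `metadata` (any picklable object); other ranks receive it.

    Assumes the process group is already initialized on every rank, e.g. with
    dist.init_process_group(backend="nccl", ...), and that each rank has set
    its CUDA device.
    """
    obj_list = [metadata]
    # broadcast_object_list pickles the object once and broadcasts the
    # resulting buffer as a collective, avoiding repeated Ray serialization.
    dist.broadcast_object_list(obj_list, src=src)
    return obj_list[0]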

For the throughput benchmark of LLaMA-70B on 8 A100-40G GPUs:

On ShareGPT Dataset
  Before this PR: 3.01 reqs/s
  With this PR: 5.08 reqs/s
With Batch size 512, input len 1031, output len 317
  Before this PR: 2.48 reqs/s
  With this PR: 3.48 reqs/s

This PR has passed most of the tests and is ready for review.

Should be merged after #2270, #2273.

@zhuohan123 changed the title from "Do not use ray for collective communication to remove serialization overhead" to "[WIP] Do not use ray for collective communication to remove serialization overhead" on Dec 20, 2023
@zhuohan123 changed the title from "[WIP] Do not use ray for collective communication to remove serialization overhead" to "[WIP] Do not use ray for control-plane communication to remove serialization overhead" on Dec 21, 2023

def broadcast(input_, src=0):
    """Broadcast the input tensor."""
    world_size = torch.distributed.get_world_size
Contributor

This is a function. You need to call it.

Member Author

😅 great catch. Will fix
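For reference, the fix is simply to call get_world_size(). A sketch of the corrected helper follows; everything beyond the three quoted lines (the single-GPU early return and the broadcast call itself) is an assumption rather than the exact code in the PR.

import torch

def broadcast(input_, src=0):
    """Broadcast the input tensor."""
    # get_world_size is a function; it must be called.
    world_size = torch.distributed.get_world_size()
    assert 0 <= src < world_size, f"Invalid src rank ({src})"
    # Nothing to do when running on a single GPU.
    if world_size == 1:
        return input_
    torch.distributed.broadcast(input_, src=src)
    return input_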

Contributor

Lucky catch. Browsed through the code in 10s.

@Lvjinhong

Hello, when can this branch be merged into the main branch, and will it bring significant performance improvements?

Member Author

Hi @Lvjinhong, please find the performance numbers in the description of this PR. This PR is waiting for review and will be merged soon.

@zhuohan123 changed the title from "[WIP] Do not use ray for control-plane communication to remove serialization overhead" to "Use NCCL instead of ray for control-plane communication to remove serialization overhead" on Dec 26, 2023
@esmeetu (Collaborator) commented Dec 27, 2023

Hi @zhuohan123, this PR doesn't work for me.
python -m vllm.entrypoints.openai.api_server --model /Llama-2-7b-chat-hf --tensor-parallel-size 4 --dtype half --enforce-eager


  File "/home/roy/vllm/vllm/engine/llm_engine.py", line 599, in _process_model_outputs
    for seq_group, outputs in zip(scheduled_seq_groups, output):
TypeError: 'NoneType' object is not iterable

@zhuohan123 (Member Author)

> Hi @zhuohan123, this PR doesn't work for me. python -m vllm.entrypoints.openai.api_server --model /Llama-2-7b-chat-hf --tensor-parallel-size 4 --dtype half --enforce-eager
>
>   File "/home/roy/vllm/vllm/engine/llm_engine.py", line 599, in _process_model_outputs
>     for seq_group, outputs in zip(scheduled_seq_groups, output):
> TypeError: 'NoneType' object is not iterable

@esmeetu Sorry this is a bug! Can you test the latest commit again?

@WoosukKwon (Collaborator) left a comment

@zhuohan123 Awesome! Thanks for addressing my comments. Let's merge this asap after addressing the remaining minor comments from @njhill!

vllm/model_executor/layers/sampler.py (resolved)
vllm/model_executor/layers/sampler.py (resolved)
vllm/worker/model_runner.py (resolved)
vllm/engine/ray_utils.py (outdated, resolved)
vllm/engine/ray_utils.py (outdated, resolved)
@Yard1 (Collaborator) commented Jan 3, 2024

@zhuohan123 Can you give me a few hours to look over this? Thanks!

Comment on lines 120 to 127
    placement_group_specs = ([{
        "GPU": 1,
        "node:__internal_head__": 0.01
    }] + [{
        "GPU": 1
-   }] * parallel_config.world_size)
+   }] * (parallel_config.world_size - 1))
    current_placement_group = ray.util.placement_group(
        placement_group_specs)
Member Author

@Yard1 In this PR, we assume the placement group will have one bundle that includes "node:__internal_head__". This reserves a GPU for the driver worker. This may conflict with some of the existing logic that uses placement groups.
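For context, a runnable sketch of how such a placement group can be constructed. The world size and the standalone setup here are illustrative assumptions, not the exact vLLM code, and it assumes a Ray cluster with at least world_size GPUs and a head node exposing the "node:__internal_head__" resource.

import ray

ray.init()

world_size = 4  # assumed tensor-parallel size, for illustration only
# One bundle pinned to the head node via the special "node:__internal_head__"
# resource reserves a GPU for the driver worker; the remaining
# world_size - 1 bundles are for the Ray workers.
bundles = [{"GPU": 1, "node:__internal_head__": 0.01}]
bundles += [{"GPU": 1}] * (world_size - 1)

pg = ray.util.placement_group(bundles)
ray.get(pg.ready())  # block until every bundle has been reserved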

Comment on lines 168 to 176
if (bundle.get("node:__internal_head__", 0) > 0
        and self.driver_dummy_worker is None):
    self.driver_dummy_worker = ray.remote(
        num_cpus=0,
        num_gpus=num_gpus,
        scheduling_strategy=scheduling_strategy,
        **ray_remote_kwargs,
    )(RayWorkerVllm).remote()
    continue
Member Author

@Yard1 This is where we use the bundle with the node:__internal_head__ resource to reserve resources for the driver worker.

@zhuohan123 (Member Author)

> @zhuohan123 Can you give me a few hours to look over this? Thanks!

Sure! I just highlighted several places where we changed the logic of how we use placement groups, which I think are important for you to take a look at.

vllm/engine/llm_engine.py (outdated, resolved)
vllm/engine/ray_utils.py (outdated, resolved)
@WoosukKwon mentioned this pull request Jan 3, 2024
worker = ray.remote(
    num_cpus=0,
    num_gpus=num_gpus,
    scheduling_strategy=scheduling_strategy,
    **ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
self.workers.append(worker)

worker_ip = ray.get(worker.get_node_ip.remote())
Collaborator

As a minor optimization, you could make this a separate loop so that the workers can be initialized in a non-blocking fashion, but considering that nothing much happens in __init__, I think it's OK to leave it as is (though it is an anti-pattern).
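For illustration, a self-contained sketch of the two-loop pattern being suggested, using a hypothetical DummyWorker actor in place of RayWorkerVllm (the real class and its arguments are in the snippet above):

import ray

@ray.remote(num_cpus=0)
class DummyWorker:
    """Stand-in for RayWorkerVllm, for illustration only."""
    def get_node_ip(self):
        return ray.util.get_node_ip_address()

ray.init()

# First loop: create all actors without blocking on any of them.
workers = [DummyWorker.remote() for _ in range(4)]

# Second pass: resolve the node IPs for all actors with a single ray.get,
# so actor initialization overlaps instead of serializing on each ray.get.
worker_ips = ray.get([w.get_node_ip.remote() for w in workers])
print(worker_ips)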

Member Author

Yeah, also this only happens once, so it shouldn't affect performance.

@Yard1 (Collaborator) left a comment

Thanks, looks good!

@zhuohan123 merged commit fd4ea8e into main on Jan 3, 2024
2 checks passed
jedibrillo pushed a commit to jedibrillo/vllm that referenced this pull request Jan 5, 2024
quanliu1991 added a commit to quanliu1991/vllm that referenced this pull request Jan 6, 2024
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
@zhuohan123 deleted the remove-serialization-overhead branch on February 22, 2024