Proposal to improve performance
When we use tensor parallelism in vLLM, the driver worker needs to broadcast some metadata to all workers, such as the input metadata and the LoRA requests. This functionality is currently implemented in:
`def broadcast_tensor_dict(` in `vllm/vllm/distributed/communication_op.py`, line 143 at commit `9c7306a`.
In essence, it uses `torch.distributed.broadcast_object_list` to broadcast a Python object. This function has a lot of overhead. Roughly, the procedure is: pickle each object on the CPU, broadcast a tensor holding the pickled sizes, concatenate the pickled bytes into one tensor, move that tensor to the GPU (when using a NCCL group), broadcast it, then move it back to the CPU and unpickle on every receiving rank.
There are three layers of overhead (see the sketch after this list):
- device memory moves: pickle works only on CPU memory, so with a NCCL group the data has to be moved between CPU and GPU back and forth.
- the pickled data of multiple objects are concatenated, which costs one extra memory copy.
- two broadcast operations are needed: one to broadcast the size of each pickled object, and another to broadcast the data itself.
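For reference, here is a simplified sketch of what `torch.distributed.broadcast_object_list` does on a NCCL group; the function name and structure are illustrative, not the actual torch source:

```python
import pickle
import torch
import torch.distributed as dist

def broadcast_object_list_sketch(object_list, src, group):
    """Illustrative sketch of broadcast_object_list on a NCCL group."""
    rank = dist.get_rank()
    device = torch.device("cuda", torch.cuda.current_device())
    if rank == src:
        pickled = [pickle.dumps(obj) for obj in object_list]  # CPU-side pickling
        sizes = torch.tensor([len(p) for p in pickled],
                             dtype=torch.long, device=device)
    else:
        sizes = torch.empty(len(object_list), dtype=torch.long, device=device)
    dist.broadcast(sizes, src=src, group=group)  # broadcast #1: sizes
    if rank == src:
        # overhead 2: concatenate pickled payloads (extra memory copy),
        # overhead 1: move the bytes from CPU to GPU
        data = torch.frombuffer(bytearray(b"".join(pickled)),
                                dtype=torch.uint8).to(device)
    else:
        data = torch.empty(int(sizes.sum().item()),
                           dtype=torch.uint8, device=device)
    dist.broadcast(data, src=src, group=group)  # broadcast #2: payload
    if rank != src:
        raw = data.cpu().numpy().tobytes()  # overhead 1 again: GPU -> CPU
        offset = 0
        for i, n in enumerate(sizes.tolist()):
            object_list[i] = pickle.loads(raw[offset:offset + n])
            offset += n
```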
The current vLLM implementation packs the data into a list of size one, so overhead 2 is eliminated:
`vllm/vllm/distributed/communication_op.py`, lines 173 to 175 at commit `9c7306a`:

```python
torch.distributed.broadcast_object_list([metadata_list],
                                        src=src,
                                        group=group)
```
To remove overhead 1, we can use a CPU backend (gloo) to broadcast this kind of metadata.
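A minimal sketch of that idea: create a parallel gloo process group over the same ranks and use it for metadata only. The group creation assumes the default process group is already initialized, and `cpu_group` plus the example payload are names I made up:

```python
import torch.distributed as dist

# assumes dist.init_process_group(...) has already been called with the
# NCCL backend for tensor data; create a parallel gloo group over the
# same ranks so metadata broadcasts stay in CPU memory
cpu_group = dist.new_group(backend="gloo")

# with a gloo group, broadcast_object_list pickles and broadcasts
# entirely on the CPU, so no CPU<->GPU copies are involved
metadata_list = [{"num_seqs": 4, "is_prompt": True}]  # illustrative payload
dist.broadcast_object_list(metadata_list, src=0, group=cpu_group)
```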
In addition, if we know the rough size of the pickled object in advance, we can remove overhead 3 as well: only one broadcast is required, which is the optimal case for broadcasting a Python object. A sketch of this scheme follows.
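This is a minimal sketch of the single-broadcast scheme, assuming a gloo CPU group; `MAX_METADATA_BYTES` and `broadcast_metadata` are my own illustrative names, not vLLM's:

```python
import pickle
import torch
import torch.distributed as dist

MAX_METADATA_BYTES = 1 << 16  # assumed upper bound on the pickled size

def broadcast_metadata(obj=None, src=0, group=None):
    """Broadcast a small Python object with a single CPU broadcast.

    A fixed-size uint8 buffer (8-byte length header + payload) is
    broadcast once; no size-exchange round trip is needed because
    every rank already knows the buffer size.
    """
    rank = dist.get_rank()
    buf = torch.zeros(8 + MAX_METADATA_BYTES, dtype=torch.uint8)
    if rank == src:
        data = pickle.dumps(obj)
        assert len(data) <= MAX_METADATA_BYTES, "object exceeds assumed bound"
        packed = len(data).to_bytes(8, "little") + data
        buf[:len(packed)] = torch.frombuffer(bytearray(packed),
                                             dtype=torch.uint8)
    dist.broadcast(buf, src=src, group=group)  # the only broadcast
    if rank != src:
        raw = bytes(buf.numpy())
        n = int.from_bytes(raw[:8], "little")
        obj = pickle.loads(raw[8:8 + n])
    return obj
```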
I wrote some benchmark code in https://gist.github.com/youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b and the results are in https://docs.google.com/spreadsheets/d/1c9xgR0fGvm6SROfk7vrjwOZdYnKQk9oOafWK4_KgOyo/edit?usp=sharing .
The short conclusion is:
- using the CPU (gloo) to broadcast the data indeed works better than NCCL (GPU). For small metadata, the broadcast time drops from about 400us to about 300us.
- if we can estimate the rough size in advance, the broadcast time can be further reduced to about 100us. That requires us to design the object to be broadcast accordingly.
Report of performance regression
No response
Misc discussion on performance
No response
Your current environment (if you think it is necessary)
The output of `python collect_env.py`