Simplify broadcast logic for control messages (vllm-project#2501)
zhuohan123 authored Jan 19, 2024
1 parent 2709c00 commit ef9b636
Showing 4 changed files with 146 additions and 129 deletions.
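This commit replaces the previous pattern for control messages — a broadcast_object_list call for Python metadata followed by a separate broadcast per tensor — with a single broadcast_tensor_dict helper. Below is a minimal sketch of the new calling convention; it is not part of the commit, it assumes torch.distributed is already initialized on every rank, and the wrapper function name and dictionary keys are illustrative only.

import torch

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def exchange_control_message(rank: int):
    # Hypothetical wrapper for illustration; the keys are examples, not an API.
    if rank == 0:
        # The source rank passes the dictionary. Values may be CUDA tensors
        # or ordinary picklable Python objects.
        broadcast_tensor_dict(
            {
                "num_seq_groups": 4,
                "slot_mapping": torch.zeros(8, dtype=torch.long, device="cuda"),
            },
            src=0)
    else:
        # Non-source ranks call with only src and get the same dictionary back,
        # with tensors allocated on the local CUDA device.
        data = broadcast_tensor_dict(src=0)
        num_seq_groups = data["num_seq_groups"]
        slot_mapping = data["slot_mapping"]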
35 changes: 33 additions & 2 deletions tests/distributed/test_comm_ops.py
@@ -11,6 +11,7 @@
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_all_gather,
broadcast_tensor_dict,
)
from vllm.worker.worker import _init_distributed_environment

@@ -64,11 +65,41 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
assert torch.allclose(t, expected)


@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
test_dict = {
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
"b": torch.arange(16, dtype=torch.int8, device="cuda"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
}

if rank == 0:
broadcast_tensor_dict(test_dict, src=0)
else:
recv_dict = broadcast_tensor_dict(src=0)
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    # Using ray helps debug the error when it fails,
    # compared to multiprocessing.
73 changes: 68 additions & 5 deletions vllm/model_executor/parallel_utils/communication_op.py
@@ -1,3 +1,6 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

import torch

from vllm.model_executor.parallel_utils.parallel_state import (
@@ -7,7 +10,7 @@
)


def tensor_model_parallel_all_reduce(input_):
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group.
NOTE: This operation is applied in-place on the input tensor.
@@ -21,7 +24,8 @@ def tensor_model_parallel_all_reduce(input_):
return input_


def tensor_model_parallel_all_gather(input_, dim=-1):
def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
@@ -48,7 +52,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
return output_tensor


def tensor_model_parallel_gather(input_, dst=0, dim=-1):
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group.
NOTE: We assume that the input tensor is on the same device across
@@ -80,7 +86,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
return output_tensor


def broadcast(input_, src=0):
def broadcast(input_: torch.Tensor, src: int = 0):
"""Broadcast the input tensor."""
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"
@@ -93,7 +99,7 @@ def broadcast(input_, src=0):
return input_


def broadcast_object_list(obj_list, src=0):
def broadcast_object_list(obj_list: List[Any], src: int = 0):
"""Broadcast the input object list."""
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"
@@ -104,3 +110,60 @@ def broadcast_object_list(obj_list, src=0):
# Broadcast.
torch.distributed.broadcast_object_list(obj_list, src=src)
return obj_list


TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
Any]]] = None,
src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
"""Broadcast the input tensor dictionary."""
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"

# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return tensor_dict

if rank == src:
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
f"Tensor {key}: {value} is not on cuda. Currently we only "
f"support broadcasting tensors on cuda.")
metadata_list.append(
(key, TensorMetadata(value.dtype, value.size())))
else:
metadata_list.append((key, value))
torch.distributed.broadcast_object_list([metadata_list], src=src)
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src)
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
metadata_list = recv_metadata_list[0]
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device="cuda")
async_handle = torch.distributed.broadcast(tensor,
src=src,
async_op=True)
async_handles.append(async_handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
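
A note on the design of broadcast_tensor_dict above: the source rank first builds a metadata list of (key, value) pairs in which each CUDA tensor is replaced by a TensorMetadata(dtype, size) entry while non-tensor values are included as-is, and broadcasts that list with broadcast_object_list. The tensor payloads are then broadcast separately; receiving ranks pre-allocate empty CUDA buffers from the metadata, issue their tensor broadcasts with async_op=True, and wait on all handles before returning. Only the small metadata list goes through object (pickle) serialization, and the is_cuda assertion makes the current CUDA-only restriction explicit.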
138 changes: 30 additions & 108 deletions vllm/worker/model_runner.py
@@ -9,7 +9,7 @@
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
broadcast, broadcast_object_list)
broadcast_tensor_dict)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import in_wsl
@@ -393,121 +393,43 @@ def prepare_input_tensors(
prompt_lens,
subquery_lens)

def get_size_or_none(x: Optional[torch.Tensor]):
return x.size() if x is not None else None

# Broadcast the input data. For input tensors, we first broadcast
# its shape and then broadcast the tensor to avoid high
# serialization cost.
py_data = {
"input_tokens_size":
input_tokens.size(),
"input_positions_size":
input_positions.size(),
"is_prompt":
input_metadata.is_prompt,
"slot_mapping_size":
get_size_or_none(input_metadata.slot_mapping),
"prompt_lens_size":
get_size_or_none(input_metadata.prompt_lens),
"max_seq_len":
input_metadata.max_seq_len,
"start_loc_size":
get_size_or_none(input_metadata.start_loc),
"max_context_len":
input_metadata.max_context_len,
"context_lens_size":
get_size_or_none(input_metadata.context_lens),
"block_tables_size":
get_size_or_none(input_metadata.block_tables),
"use_cuda_graph":
input_metadata.use_cuda_graph,
"selected_token_indices_size":
sampling_metadata.selected_token_indices.size(),
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"is_prompt": input_metadata.is_prompt,
"slot_mapping": input_metadata.slot_mapping,
"prompt_lens": input_metadata.prompt_lens,
"max_seq_len": input_metadata.max_seq_len,
"start_loc": input_metadata.start_loc,
"max_context_len": input_metadata.max_context_len,
"context_lens": input_metadata.context_lens,
"block_tables": input_metadata.block_tables,
"use_cuda_graph": input_metadata.use_cuda_graph,
"selected_token_indices":
sampling_metadata.selected_token_indices,
}
broadcast_object_list([py_data], src=0)
# TODO(zhuohan): Combine the broadcasts or set async_op=True.
broadcast(input_tokens, src=0)
broadcast(input_positions, src=0)
if input_metadata.slot_mapping is not None:
broadcast(input_metadata.slot_mapping, src=0)
if input_metadata.prompt_lens is not None:
broadcast(input_metadata.prompt_lens, src=0)
if input_metadata.start_loc is not None:
broadcast(input_metadata.start_loc, src=0)
if input_metadata.context_lens is not None:
broadcast(input_metadata.context_lens, src=0)
if input_metadata.block_tables is not None:
broadcast(input_metadata.block_tables, src=0)
broadcast(sampling_metadata.selected_token_indices, src=0)
broadcast_tensor_dict(metadata_dict, src=0)
else:
receving_list = [None]
broadcast_object_list(receving_list, src=0)
py_data = receving_list[0]
input_tokens = torch.empty(*py_data["input_tokens_size"],
dtype=torch.long,
device="cuda")
broadcast(input_tokens, src=0)
input_positions = torch.empty(*py_data["input_positions_size"],
dtype=torch.long,
device="cuda")
broadcast(input_positions, src=0)
if py_data["slot_mapping_size"] is not None:
slot_mapping = torch.empty(*py_data["slot_mapping_size"],
dtype=torch.long,
device="cuda")
broadcast(slot_mapping, src=0)
else:
slot_mapping = None
if py_data["prompt_lens_size"] is not None:
prompt_lens = torch.empty(*py_data["prompt_lens_size"],
dtype=torch.long,
device="cuda")
broadcast(prompt_lens, src=0)
else:
prompt_lens = None
if py_data["start_loc_size"] is not None:
start_loc = torch.empty(*py_data["start_loc_size"],
dtype=torch.long,
device="cuda")
broadcast(start_loc, src=0)
else:
start_loc = None
if py_data["context_lens_size"] is not None:
context_lens = torch.empty(*py_data["context_lens_size"],
dtype=torch.int,
device="cuda")
broadcast(context_lens, src=0)
else:
context_lens = None
if py_data["block_tables_size"] is not None:
block_tables = torch.empty(*py_data["block_tables_size"],
dtype=torch.int,
device="cuda")
broadcast(block_tables, src=0)
else:
block_tables = None
selected_token_indices = torch.empty(
*py_data["selected_token_indices_size"],
dtype=torch.long,
device="cuda")
broadcast(selected_token_indices, src=0)
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict["input_tokens"]
input_positions = metadata_dict["input_positions"]
input_metadata = InputMetadata(
is_prompt=py_data["is_prompt"],
slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
max_seq_len=py_data["max_seq_len"],
start_loc=start_loc,
max_context_len=py_data["max_context_len"],
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=py_data["use_cuda_graph"],
is_prompt=metadata_dict["is_prompt"],
slot_mapping=metadata_dict["slot_mapping"],
prompt_lens=metadata_dict["prompt_lens"],
max_seq_len=metadata_dict["max_seq_len"],
start_loc=metadata_dict["start_loc"],
max_context_len=metadata_dict["max_context_len"],
context_lens=metadata_dict["context_lens"],
block_tables=metadata_dict["block_tables"],
use_cuda_graph=metadata_dict["use_cuda_graph"],
)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=selected_token_indices,
selected_token_indices=metadata_dict["selected_token_indices"],
categorized_sample_indices=None,
perform_sampling=False,
)
29 changes: 15 additions & 14 deletions vllm/worker/worker.py
@@ -9,7 +9,7 @@
SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_object_list)
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -175,20 +175,21 @@ def execute_model(
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
block_swapping_info = [
blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
]
broadcast_object_list([num_seq_groups] + block_swapping_info,
src=0)
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
# num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
# blocks_to_copy (4 elements)
recv_data = [None] * 4
broadcast_object_list(recv_data, src=0)
num_seq_groups = recv_data[0]
block_swapping_info = recv_data[1:]

self.cache_swap(*block_swapping_info)
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]

self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
