From ef9b636e2d427f588bf11242e312ba8954d9aff0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 19 Jan 2024 11:23:30 -0800 Subject: [PATCH] Simplify broadcast logic for control messages (#2501) --- tests/distributed/test_comm_ops.py | 35 ++++- .../parallel_utils/communication_op.py | 73 ++++++++- vllm/worker/model_runner.py | 138 ++++-------------- vllm/worker/worker.py | 29 ++-- 4 files changed, 146 insertions(+), 129 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 75111feb39507..b12e563fd9d44 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -11,6 +11,7 @@ from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather, + broadcast_tensor_dict, ) from vllm.worker.worker import _init_distributed_environment @@ -64,11 +65,41 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, assert torch.allclose(t, expected) +@ray.remote(num_gpus=1, max_calls=1) +def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, + distributed_init_port: str): + init_test_distributed_environment(1, tensor_parallel_size, rank, + distributed_init_port) + test_dict = { + "a": torch.arange(8, dtype=torch.float32, device="cuda"), + "b": torch.arange(16, dtype=torch.int8, device="cuda"), + "c": "test", + "d": [1, 2, 3], + "e": { + "a": 1, + "b": 2 + }, + } + + if rank == 0: + broadcast_tensor_dict(test_dict, src=0) + else: + recv_dict = broadcast_tensor_dict(src=0) + assert len(recv_dict) == len(test_dict) + assert torch.allclose(recv_dict["a"], test_dict["a"]) + assert torch.allclose(recv_dict["b"], test_dict["b"]) + assert recv_dict["c"] == test_dict["c"] + assert recv_dict["d"] == test_dict["d"] + assert recv_dict["e"] == test_dict["e"] + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("tensor_parallel_size", [2]) -@pytest.mark.parametrize("test_target", - [all_reduce_test_worker, all_gather_test_worker]) +@pytest.mark.parametrize("test_target", [ + all_reduce_test_worker, all_gather_test_worker, + broadcast_tensor_dict_test_worker +]) def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): # Using ray helps debugging the error when it failed # as compared to multiprocessing. diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 8bf04f3d1f056..64992d05527e8 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -1,3 +1,6 @@ +from collections import namedtuple +from typing import Any, Dict, List, Optional, Union + import torch from vllm.model_executor.parallel_utils.parallel_state import ( @@ -7,7 +10,7 @@ ) -def tensor_model_parallel_all_reduce(input_): +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group. NOTE: This operation is applied in-place on the input tensor. @@ -21,7 +24,8 @@ def tensor_model_parallel_all_reduce(input_): return input_ -def tensor_model_parallel_all_gather(input_, dim=-1): +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" world_size = get_tensor_model_parallel_world_size() # Bypass the function if we are using only 1 GPU. @@ -48,7 +52,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1): return output_tensor -def tensor_model_parallel_gather(input_, dst=0, dim=-1): +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> torch.Tensor: """Gather the input tensor across model parallel group. NOTE: We assume that the input tensor is on the same device across @@ -80,7 +86,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1): return output_tensor -def broadcast(input_, src=0): +def broadcast(input_: torch.Tensor, src: int = 0): """Broadcast the input tensor.""" world_size = torch.distributed.get_world_size() assert 0 <= src < world_size, f"Invalid src rank ({src})" @@ -93,7 +99,7 @@ def broadcast(input_, src=0): return input_ -def broadcast_object_list(obj_list, src=0): +def broadcast_object_list(obj_list: List[Any], src: int = 0): """Broadcast the input object list.""" world_size = torch.distributed.get_world_size() assert 0 <= src < world_size, f"Invalid src rank ({src})" @@ -104,3 +110,60 @@ def broadcast_object_list(obj_list, src=0): # Broadcast. torch.distributed.broadcast_object_list(obj_list, src=src) return obj_list + + +TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) + + +def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]: + """Broadcast the input tensor dictionary.""" + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + assert 0 <= src < world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return tensor_dict + + if rank == src: + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + assert value.is_cuda, ( + f"Tensor {key}: {value} is not on cuda. Currently we only " + f"support broadcasting tensors on cuda.") + metadata_list.append( + (key, TensorMetadata(value.dtype, value.size()))) + else: + metadata_list.append((key, value)) + torch.distributed.broadcast_object_list([metadata_list], src=src) + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = tensor_dict[key] + torch.distributed.broadcast(tensor, src=src) + else: + recv_metadata_list = [None] + torch.distributed.broadcast_object_list(recv_metadata_list, src=src) + metadata_list = recv_metadata_list[0] + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device="cuda") + async_handle = torch.distributed.broadcast(tensor, + src=src, + async_op=True) + async_handles.append(async_handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index d290886506507..8e764e73bf41c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,7 +9,7 @@ from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( - broadcast, broadcast_object_list) + broadcast_tensor_dict) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.utils import in_wsl @@ -393,121 +393,43 @@ def prepare_input_tensors( prompt_lens, subquery_lens) - def get_size_or_none(x: Optional[torch.Tensor]): - return x.size() if x is not None else None - - # Broadcast the input data. For input tensors, we first broadcast - # its shape and then broadcast the tensor to avoid high - # serialization cost. - py_data = { - "input_tokens_size": - input_tokens.size(), - "input_positions_size": - input_positions.size(), - "is_prompt": - input_metadata.is_prompt, - "slot_mapping_size": - get_size_or_none(input_metadata.slot_mapping), - "prompt_lens_size": - get_size_or_none(input_metadata.prompt_lens), - "max_seq_len": - input_metadata.max_seq_len, - "start_loc_size": - get_size_or_none(input_metadata.start_loc), - "max_context_len": - input_metadata.max_context_len, - "context_lens_size": - get_size_or_none(input_metadata.context_lens), - "block_tables_size": - get_size_or_none(input_metadata.block_tables), - "use_cuda_graph": - input_metadata.use_cuda_graph, - "selected_token_indices_size": - sampling_metadata.selected_token_indices.size(), + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "is_prompt": input_metadata.is_prompt, + "slot_mapping": input_metadata.slot_mapping, + "prompt_lens": input_metadata.prompt_lens, + "max_seq_len": input_metadata.max_seq_len, + "start_loc": input_metadata.start_loc, + "max_context_len": input_metadata.max_context_len, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, + "use_cuda_graph": input_metadata.use_cuda_graph, + "selected_token_indices": + sampling_metadata.selected_token_indices, } - broadcast_object_list([py_data], src=0) - # TODO(zhuohan): Combine the broadcasts or set async_op=True. - broadcast(input_tokens, src=0) - broadcast(input_positions, src=0) - if input_metadata.slot_mapping is not None: - broadcast(input_metadata.slot_mapping, src=0) - if input_metadata.prompt_lens is not None: - broadcast(input_metadata.prompt_lens, src=0) - if input_metadata.start_loc is not None: - broadcast(input_metadata.start_loc, src=0) - if input_metadata.context_lens is not None: - broadcast(input_metadata.context_lens, src=0) - if input_metadata.block_tables is not None: - broadcast(input_metadata.block_tables, src=0) - broadcast(sampling_metadata.selected_token_indices, src=0) + broadcast_tensor_dict(metadata_dict, src=0) else: - receving_list = [None] - broadcast_object_list(receving_list, src=0) - py_data = receving_list[0] - input_tokens = torch.empty(*py_data["input_tokens_size"], - dtype=torch.long, - device="cuda") - broadcast(input_tokens, src=0) - input_positions = torch.empty(*py_data["input_positions_size"], - dtype=torch.long, - device="cuda") - broadcast(input_positions, src=0) - if py_data["slot_mapping_size"] is not None: - slot_mapping = torch.empty(*py_data["slot_mapping_size"], - dtype=torch.long, - device="cuda") - broadcast(slot_mapping, src=0) - else: - slot_mapping = None - if py_data["prompt_lens_size"] is not None: - prompt_lens = torch.empty(*py_data["prompt_lens_size"], - dtype=torch.long, - device="cuda") - broadcast(prompt_lens, src=0) - else: - prompt_lens = None - if py_data["start_loc_size"] is not None: - start_loc = torch.empty(*py_data["start_loc_size"], - dtype=torch.long, - device="cuda") - broadcast(start_loc, src=0) - else: - start_loc = None - if py_data["context_lens_size"] is not None: - context_lens = torch.empty(*py_data["context_lens_size"], - dtype=torch.int, - device="cuda") - broadcast(context_lens, src=0) - else: - context_lens = None - if py_data["block_tables_size"] is not None: - block_tables = torch.empty(*py_data["block_tables_size"], - dtype=torch.int, - device="cuda") - broadcast(block_tables, src=0) - else: - block_tables = None - selected_token_indices = torch.empty( - *py_data["selected_token_indices_size"], - dtype=torch.long, - device="cuda") - broadcast(selected_token_indices, src=0) + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict["input_tokens"] + input_positions = metadata_dict["input_positions"] input_metadata = InputMetadata( - is_prompt=py_data["is_prompt"], - slot_mapping=slot_mapping, - prompt_lens=prompt_lens, - max_seq_len=py_data["max_seq_len"], - start_loc=start_loc, - max_context_len=py_data["max_context_len"], - context_lens=context_lens, - block_tables=block_tables, - use_cuda_graph=py_data["use_cuda_graph"], + is_prompt=metadata_dict["is_prompt"], + slot_mapping=metadata_dict["slot_mapping"], + prompt_lens=metadata_dict["prompt_lens"], + max_seq_len=metadata_dict["max_seq_len"], + start_loc=metadata_dict["start_loc"], + max_context_len=metadata_dict["max_context_len"], + context_lens=metadata_dict["context_lens"], + block_tables=metadata_dict["block_tables"], + use_cuda_graph=metadata_dict["use_cuda_graph"], ) sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, prompt_lens=None, - selected_token_indices=selected_token_indices, + selected_token_indices=metadata_dict["selected_token_indices"], categorized_sample_indices=None, perform_sampling=False, ) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c2a2ac148085b..70168f4f60e28 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,7 @@ SchedulerConfig) from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_object_list) + broadcast_tensor_dict) from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -175,20 +175,21 @@ def execute_model( assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None assert blocks_to_copy is not None - block_swapping_info = [ - blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy - ] - broadcast_object_list([num_seq_groups] + block_swapping_info, - src=0) + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) else: - # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, - # blocks_to_copy (4 elements) - recv_data = [None] * 4 - broadcast_object_list(recv_data, src=0) - num_seq_groups = recv_data[0] - block_swapping_info = recv_data[1:] - - self.cache_swap(*block_swapping_info) + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_swap_in = data["blocks_to_swap_in"] + blocks_to_swap_out = data["blocks_to_swap_out"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. if num_seq_groups == 0: