From 22de8ae2b9122021dbe88723653a87c4fd566fc9 Mon Sep 17 00:00:00 2001
From: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Date: Thu, 27 Jun 2024 00:15:24 -0700
Subject: [PATCH] [VLM][Bugfix] Make sure that `multi_modal_kwargs` is
 broadcasted properly (#5880)

Signed-off-by: Xiaowei Jiang
Signed-off-by: Alvant
---
 .buildkite/test-pipeline.yaml            |  4 +-
 tests/distributed/test_parallel_state.py | 49 ++++++++++++++++++++++++
 vllm/distributed/parallel_state.py       | 37 ++++++++++++++----
 3 files changed, 81 insertions(+), 9 deletions(-)
 create mode 100644 tests/distributed/test_parallel_state.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 10cfe35d85be4..fa37d0c7539f6 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -27,7 +27,9 @@ steps:
 
 - label: Core Test
   mirror_hardwares: [amd]
-  command: pytest -v -s core
+  commands:
+  - pytest -v -s core
+  - pytest -v -s distributed/test_parallel_state.py
 
 - label: Distributed Comm Ops Test
   #mirror_hardwares: [amd]
diff --git a/tests/distributed/test_parallel_state.py b/tests/distributed/test_parallel_state.py
new file mode 100644
index 0000000000000..5d293b2c16c44
--- /dev/null
+++ b/tests/distributed/test_parallel_state.py
@@ -0,0 +1,49 @@
+from typing import Any, Dict
+
+import torch
+
+from vllm.distributed.parallel_state import (_split_tensor_dict,
+                                             _update_nested_dict)
+
+
+def test_split_tensor_dict():
+    test_dict = {
+        "key_a": "a",
+        "key_b": torch.arange(8, dtype=torch.float32),
+        "key_c": {
+            "key_1": torch.arange(5, dtype=torch.float32),
+            "key_2": torch.tensor([], dtype=torch.float32),
+            "key_3": 123,
+        },
+        "key_d": {},
+    }
+    metadata_list, tensor_list = _split_tensor_dict(test_dict)
+    assert len(metadata_list) == 6
+    assert torch.allclose(tensor_list[0], test_dict["key_b"])
+    assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"])
+    assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"])
+
+
+def test_update_nested_dict():
+    flattened_keys_values = [("key1%key2%key3", "value1"),
+                             ("key1%key2%key4", "value2"),
+                             ("key1%key5", "value3"), ("key6%key7", "value4"),
+                             ("key8", "value5")]
+    res: Dict[str, Any] = {}
+
+    # Update the nested dictionary with each flattened key-value pair
+    for flat_key, value in flattened_keys_values:
+        _update_nested_dict(res, flat_key, value)
+    assert res == {
+        "key1": {
+            "key2": {
+                "key3": "value1",
+                "key4": "value2"
+            },
+            "key5": "value3"
+        },
+        "key6": {
+            "key7": "value4"
+        },
+        "key8": "value5"
+    }
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index a7a806b055681..1f6b05e8631a8 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -45,14 +45,17 @@ class GraphCaptureContext:
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+    prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
          by its metadata.
     2. A list of tensors.
+
+    If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
+    metadata will be "key1%key2".
     """
-    metadata_list = []
+    metadata_list: List[Tuple[str, Any]] = []
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
@@ -62,13 +65,31 @@ def _split_tensor_dict(
             # receiving side will set the device index.
             device = value.device.type
             metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
+                (prefix + key, TensorMetadata(device, value.dtype,
+                                              value.size())))
             tensor_list.append(value)
+        elif isinstance(value, dict):
+            if len(value) == 0:
+                metadata_list.append((prefix + key, value))
+            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
+                value, prefix + key + "%")
+            metadata_list.extend(inner_metadata_list)
+            tensor_list.extend(inner_tensor_list)
         else:
-            metadata_list.append((key, value))
+            metadata_list.append((prefix + key, value))
     return metadata_list, tensor_list
 
 
+def _update_nested_dict(nested_dict, flattened_key, value):
+    key_splits = flattened_key.split("%")
+    cur_dict = nested_dict
+    for k in key_splits[:-1]:
+        if k not in cur_dict:
+            cur_dict[k] = {}
+        cur_dict = cur_dict[k]
+    cur_dict[key_splits[-1]] = value
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
@@ -512,7 +533,7 @@ def broadcast_tensor_dict(
                                          device=value.device)
                     if tensor.numel() == 0:
                         # Skip broadcasting empty tensors.
-                        tensor_dict[key] = tensor
+                        _update_nested_dict(tensor_dict, key, tensor)
                         continue
                     if tensor.is_cpu:
                         # use metadata_group for CPU tensors
@@ -528,9 +549,9 @@ def broadcast_tensor_dict(
                                                              group=group,
                                                              async_op=True)
                     async_handles.append(handle)
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                 else:
-                    tensor_dict[key] = value
+                    _update_nested_dict(tensor_dict, key, value)
             for async_handle in async_handles:
                 async_handle.wait()
         return tensor_dict