Skip to content

Commit

Permalink
[VLM][Bugfix] Make sure that multi_modal_kwargs is broadcasted prop…
Browse files Browse the repository at this point in the history
…erly (#5880)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
  • Loading branch information
xwjiang2010 authored Jun 27, 2024
1 parent 6eabc6c commit d12af20
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 9 deletions.
4 changes: 3 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ steps:

- label: Core Test
mirror_hardwares: [amd]
command: pytest -v -s core
commands:
- pytest -v -s core
- pytest -v -s distributed/test_parallel_state.py

- label: Distributed Comm Ops Test
#mirror_hardwares: [amd]
Expand Down
49 changes: 49 additions & 0 deletions tests/distributed/test_parallel_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Any, Dict

import torch

from vllm.distributed.parallel_state import (_split_tensor_dict,
_update_nested_dict)


def test_split_tensor_dict():
test_dict = {
"key_a": "a",
"key_b": torch.arange(8, dtype=torch.float32),
"key_c": {
"key_1": torch.arange(5, dtype=torch.float32),
"key_2": torch.tensor([], dtype=torch.float32),
"key_3": 123,
},
"key_d": {},
}
metadata_list, tensor_list = _split_tensor_dict(test_dict)
assert len(metadata_list) == 6
assert torch.allclose(tensor_list[0], test_dict["key_b"])
assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"])
assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"])


def test_update_nested_dict():
flattened_keys_values = [("key1%key2%key3", "value1"),
("key1%key2%key4", "value2"),
("key1%key5", "value3"), ("key6%key7", "value4"),
("key8", "value5")]
res: Dict[str, Any] = {}

# Update the nested dictionary with each flattened key-value pair
for flat_key, value in flattened_keys_values:
_update_nested_dict(res, flat_key, value)
assert res == {
"key1": {
"key2": {
"key3": "value1",
"key4": "value2"
},
"key5": "value3"
},
"key6": {
"key7": "value4"
},
"key8": "value5"
}
37 changes: 29 additions & 8 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,17 @@ class GraphCaptureContext:


def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
metadata_list = []
metadata_list: List[Tuple[str, Any]] = []
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
Expand All @@ -62,13 +65,31 @@ def _split_tensor_dict(
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
(prefix + key, TensorMetadata(device, value.dtype,
value.size())))
tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%")
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else:
metadata_list.append((key, value))
metadata_list.append((prefix + key, value))
return metadata_list, tensor_list


def _update_nested_dict(nested_dict, flattened_key, value):
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value


class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
Expand Down Expand Up @@ -512,7 +533,7 @@ def broadcast_tensor_dict(
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
_update_nested_dict(tensor_dict, key, tensor)
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
Expand All @@ -528,9 +549,9 @@ def broadcast_tensor_dict(
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
_update_nested_dict(tensor_dict, key, tensor)
else:
tensor_dict[key] = value
_update_nested_dict(tensor_dict, key, value)
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
Expand Down

0 comments on commit d12af20

Please sign in to comment.