Commit d12af20

[VLM][Bugfix] Make sure that multi_modal_kwargs is broadcasted properly (#5880)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>

Parent: 6eabc6c

3 files changed: +81 −9 lines

.buildkite/test-pipeline.yaml (3 additions, 1 deletion)

```diff
@@ -27,7 +27,9 @@ steps:
 
 - label: Core Test
   mirror_hardwares: [amd]
-  command: pytest -v -s core
+  commands:
+  - pytest -v -s core
+  - pytest -v -s distributed/test_parallel_state.py
 
 - label: Distributed Comm Ops Test
   #mirror_hardwares: [amd]
```
tests/distributed/test_parallel_state.py (new file: 49 additions, 0 deletions)

```python
from typing import Any, Dict

import torch

from vllm.distributed.parallel_state import (_split_tensor_dict,
                                             _update_nested_dict)


def test_split_tensor_dict():
    test_dict = {
        "key_a": "a",
        "key_b": torch.arange(8, dtype=torch.float32),
        "key_c": {
            "key_1": torch.arange(5, dtype=torch.float32),
            "key_2": torch.tensor([], dtype=torch.float32),
            "key_3": 123,
        },
        "key_d": {},
    }
    metadata_list, tensor_list = _split_tensor_dict(test_dict)
    assert len(metadata_list) == 6
    assert torch.allclose(tensor_list[0], test_dict["key_b"])
    assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"])
    assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"])


def test_update_nested_dict():
    flattened_keys_values = [("key1%key2%key3", "value1"),
                             ("key1%key2%key4", "value2"),
                             ("key1%key5", "value3"), ("key6%key7", "value4"),
                             ("key8", "value5")]
    res: Dict[str, Any] = {}

    # Update the nested dictionary with each flattened key-value pair.
    for flat_key, value in flattened_keys_values:
        _update_nested_dict(res, flat_key, value)
    assert res == {
        "key1": {
            "key2": {
                "key3": "value1",
                "key4": "value2"
            },
            "key5": "value3"
        },
        "key6": {
            "key7": "value4"
        },
        "key8": "value5"
    }
```
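To see why the first test expects exactly six metadata entries, trace the `%`-flattening convention that `_split_tensor_dict` (shown in the parallel_state.py diff below) applies to the nested dict above. A minimal sketch; the key ordering assumes Python's insertion-ordered dicts:

```python
import torch

from vllm.distributed.parallel_state import _split_tensor_dict

test_dict = {
    "key_a": "a",
    "key_b": torch.arange(8, dtype=torch.float32),
    "key_c": {
        "key_1": torch.arange(5, dtype=torch.float32),
        "key_2": torch.tensor([], dtype=torch.float32),
        "key_3": 123,
    },
    "key_d": {},
}
metadata_list, tensor_list = _split_tensor_dict(test_dict)

# Flattened entries, in traversal order:
#   "key_a"        -> "a"             plain value, kept as-is
#   "key_b"        -> TensorMetadata  tensor replaced by its metadata
#   "key_c%key_1"  -> TensorMetadata  nested tensor, "%"-joined key
#   "key_c%key_2"  -> TensorMetadata  empty tensor still split out here
#   "key_c%key_3"  -> 123             nested plain value
#   "key_d"        -> {}              empty nested dict kept as a value
assert [k for k, _ in metadata_list] == [
    "key_a", "key_b", "key_c%key_1", "key_c%key_2", "key_c%key_3", "key_d"
]
assert len(tensor_list) == 3  # key_b, key_c%key_1, key_c%key_2
```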

vllm/distributed/parallel_state.py (29 additions, 8 deletions)

```diff
@@ -45,14 +45,17 @@ class GraphCaptureContext:
 
 
 def _split_tensor_dict(
-    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
-) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+    prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
        by its metadata.
     2. A list of tensors.
+
+    If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
+    metadata will be "key1%key2".
     """
-    metadata_list = []
+    metadata_list: List[Tuple[str, Any]] = []
     tensor_list = []
     for key, value in tensor_dict.items():
         if isinstance(value, torch.Tensor):
@@ -62,13 +65,31 @@ def _split_tensor_dict(
             # receiving side will set the device index.
             device = value.device.type
             metadata_list.append(
-                (key, TensorMetadata(device, value.dtype, value.size())))
+                (prefix + key, TensorMetadata(device, value.dtype,
+                                              value.size())))
             tensor_list.append(value)
+        elif isinstance(value, dict):
+            if len(value) == 0:
+                metadata_list.append((prefix + key, value))
+            inner_metadata_list, inner_tensor_list = _split_tensor_dict(
+                value, prefix + key + "%")
+            metadata_list.extend(inner_metadata_list)
+            tensor_list.extend(inner_tensor_list)
         else:
-            metadata_list.append((key, value))
+            metadata_list.append((prefix + key, value))
     return metadata_list, tensor_list
 
 
+def _update_nested_dict(nested_dict, flattened_key, value):
+    key_splits = flattened_key.split("%")
+    cur_dict = nested_dict
+    for k in key_splits[:-1]:
+        if k not in cur_dict:
+            cur_dict[k] = {}
+        cur_dict = cur_dict[k]
+    cur_dict[key_splits[-1]] = value
+
+
 class GroupCoordinator:
     """
     PyTorch ProcessGroup wrapper for a group of processes.
```
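Taken together, `_split_tensor_dict` and `_update_nested_dict` are inverses over the dict structure: the sender flattens nested keys with `%`, the receiver rebuilds the nesting entry by entry. A minimal single-process sketch of that round trip, with no actual broadcast; it assumes `TensorMetadata` is importable from this module (the diff uses it unqualified there), and the payload keys are illustrative only:

```python
from typing import Any, Dict

import torch

from vllm.distributed.parallel_state import (TensorMetadata,
                                             _split_tensor_dict,
                                             _update_nested_dict)

# Hypothetical multimodal payload; the key names are made up for illustration.
original = {
    "input_ids": torch.arange(4),
    "multi_modal_kwargs": {"pixel_values": torch.rand(1, 3, 2, 2)},
}

# Sending side: flatten nested keys and pull out the tensors.
metadata_list, tensor_list = _split_tensor_dict(original)

# Receiving side: rebuild the nested dict from the flattened entries.
# In broadcast_tensor_dict the tensors arrive via torch.distributed
# broadcasts; here we simply reuse the sender's tensors in order.
rebuilt: Dict[Any, Any] = {}
tensors = iter(tensor_list)
for flat_key, value in metadata_list:
    if isinstance(value, TensorMetadata):
        _update_nested_dict(rebuilt, flat_key, next(tensors))
    else:
        _update_nested_dict(rebuilt, flat_key, value)

assert torch.equal(rebuilt["multi_modal_kwargs"]["pixel_values"],
                   original["multi_modal_kwargs"]["pixel_values"])
```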
```diff
@@ -512,7 +533,7 @@ def broadcast_tensor_dict(
                                          device=value.device)
                     if tensor.numel() == 0:
                         # Skip broadcasting empty tensors.
-                        tensor_dict[key] = tensor
+                        _update_nested_dict(tensor_dict, key, tensor)
                         continue
                     if tensor.is_cpu:
                         # use metadata_group for CPU tensors
@@ -528,9 +549,9 @@ def broadcast_tensor_dict(
                             group=group,
                             async_op=True)
                     async_handles.append(handle)
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                 else:
-                    tensor_dict[key] = value
+                    _update_nested_dict(tensor_dict, key, value)
             for async_handle in async_handles:
                 async_handle.wait()
         return tensor_dict
```
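These three replaced assignments are the heart of the fix: now that the sender flattens nested keys, the keys arriving in the receiver's metadata list are the `%`-joined ones, so plain `tensor_dict[key] = ...` would store them as literal flat keys instead of restoring the nesting that callers reading `multi_modal_kwargs` expect. A small sketch of the difference; the `pixel_values` sub-key is a made-up example, not taken from this diff:

```python
from typing import Any, Dict

from vllm.distributed.parallel_state import _update_nested_dict

flat_key = "multi_modal_kwargs%pixel_values"  # illustrative flattened key

# Plain assignment: the flattened key is stored verbatim on the receiver,
# so a lookup like d["multi_modal_kwargs"] raises KeyError.
buggy: Dict[str, Any] = {}
buggy[flat_key] = "tensor-placeholder"
assert "multi_modal_kwargs" not in buggy

# With _update_nested_dict: the intermediate dicts are re-created.
fixed: Dict[str, Any] = {}
_update_nested_dict(fixed, flat_key, "tensor-placeholder")
assert fixed == {"multi_modal_kwargs": {"pixel_values": "tensor-placeholder"}}
```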
