
Commit 74bf88f

xwjiang2010 authored and ywang96 committed
[VLM][BugFix] Make sure that multi_modal_kwargs can broadcast properly with ring buffer. (vllm-project#5905)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 5f1316e commit 74bf88f

File tree

1 file changed: +10 -10 lines changed

vllm/distributed/parallel_state.py

Lines changed: 10 additions & 10 deletions
@@ -45,7 +45,7 @@ class GraphCaptureContext:
 
 
 def _split_tensor_dict(
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
@@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:
 
     def broadcast_tensor_dict(
         self,
-        tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
         metadata_group: Optional[ProcessGroup] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -558,9 +558,9 @@ def broadcast_tensor_dict(
 
     def send_tensor_dict(
         self,
-        tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
+        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
@@ -599,7 +599,7 @@ def send_tensor_dict(
     def recv_tensor_dict(
         self,
         src: Optional[int] = None
-    ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
@@ -615,15 +615,15 @@ def recv_tensor_dict(
         assert src < self.world_size, f"Invalid src rank ({src})"
 
         recv_metadata_list = self.recv_object(src=src)
-        tensor_dict = {}
+        tensor_dict: Dict[str, Any] = {}
         for key, value in recv_metadata_list:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
                                      device=value.device)
                 if tensor.numel() == 0:
                     # Skip broadcasting empty tensors.
-                    tensor_dict[key] = tensor
+                    _update_nested_dict(tensor_dict, key, tensor)
                     continue
                 if tensor.is_cpu:
                     # use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ def recv_tensor_dict(
                 else:
                     # use group for GPU tensors
                     torch.distributed.recv(tensor, src=src, group=group)
-                tensor_dict[key] = tensor
+                _update_nested_dict(tensor_dict, key, tensor)
             else:
-                tensor_dict[key] = value
+                _update_nested_dict(tensor_dict, key, value)
         return tensor_dict
 
     def barrier(self):
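The rewritten assignments call `_update_nested_dict`, a helper added elsewhere in this commit but not shown in the hunks above. A minimal sketch of what such a helper could look like, assuming `_split_tensor_dict` flattens nested keys by joining nesting levels with a "%" separator (the separator is not visible in this excerpt, so it is an assumption here):

from typing import Any, Dict

def _update_nested_dict(nested_dict: Dict[str, Any], flattened_key: str,
                        value: Any) -> None:
    # Walk the "%"-separated segments of the flattened key, creating
    # intermediate dicts as needed, and store the value under the final
    # segment. The "%" separator is an assumption for this sketch.
    key_splits = flattened_key.split("%")
    cur_dict = nested_dict
    for k in key_splits[:-1]:
        cur_dict = cur_dict.setdefault(k, {})
    cur_dict[key_splits[-1]] = value

This restores the nesting on the receiving rank, so `multi_modal_kwargs` arrives as a dict of tensors rather than under flat "parent%child" keys, which is what the replaced `tensor_dict[key] = ...` lines produced.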

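For illustration, a hypothetical round trip built on the sketch above: the sending rank flattens the nested `multi_modal_kwargs` so each tensor can be transferred individually, and the receiving rank rebuilds the original nesting. The key names and tensor shapes below are made up for the example:

import torch

# Hypothetical flattened (key, tensor) pairs as the receiver would see them.
received = [
    ("multi_modal_kwargs%pixel_values", torch.zeros(1, 3, 336, 336)),
    ("multi_modal_kwargs%image_sizes", torch.tensor([[336, 336]])),
]

rebuilt: Dict[str, Any] = {}
for key, value in received:
    _update_nested_dict(rebuilt, key, value)

# The nested structure is recovered on the receiving side.
assert rebuilt["multi_modal_kwargs"]["pixel_values"].shape == (1, 3, 336, 336)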