Skip to content

[Core] Batch multi modal input using pinned memory #19169

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits on Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions vllm/multimodal/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,8 @@ def modalities(self):
return self._items_by_modality.keys()

@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
def _try_stack(nested_tensors: NestedTensors,
pin_memory: bool = False) -> NestedTensors:
"""
Stack the inner dimensions that have the same shape in
a nested list of tensors.
Expand All @@ -697,7 +698,9 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)

stacked = [MultiModalKwargs._try_stack(t) for t in nested_tensors]
stacked = [
MultiModalKwargs._try_stack(t, pin_memory) for t in nested_tensors
]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
return stacked
Expand All @@ -713,10 +716,16 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
# The tensors have incompatible shapes and can't be stacked.
return tensors_

return torch.stack(tensors_)
outputs = torch.empty(len(tensors_),
*tensors_[0].shape,
dtype=tensors_[0].dtype,
device=tensors_[0].device,
pin_memory=pin_memory)
return torch.stack(tensors_, out=outputs)

@staticmethod
def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
def batch(inputs_list: list["MultiModalKwargs"],
pin_memory: bool = False) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.

Expand All @@ -738,7 +747,7 @@ def batch(inputs_list: list["MultiModalKwargs"]) -> BatchedTensorInputs:
item_lists[k].append(v)

return {
k: MultiModalKwargs._try_stack(item_list)
k: MultiModalKwargs._try_stack(item_list, pin_memory)
for k, item_list in item_lists.items()
}

Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,8 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):

encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.batch(
grouped_mm_inputs, pin_memory=self.pin_memory)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
device=self.device,
Expand Down Expand Up @@ -1952,7 +1953,8 @@ def profile_run(self) -> None:
).multi_modal_data

batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
[dummy_mm_kwargs] * max_num_mm_items,
pin_memory=self.pin_memory)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
device=self.device,
Expand Down