Skip to content

Commit 7a831ad

Browse files
DarkLight1337shreyankg
authored andcommitted
[VLM] Check required fields before initializing field config in DictEmbeddingItems (vllm-project#13380)
1 parent 4b5c19d commit 7a831ad

File tree

5 files changed

+35
-22
lines changed

5 files changed

+35
-22
lines changed

docs/source/serving/multimodal_inputs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={
184184
mm_data = {
185185
"image": {
186186
"image_embeds": image_embeds,
187-
# image_size_list is needed to calculate details of the sliced image.
188-
"image_size_list": [image.size for image in images], # list of image sizes
187+
# image_sizes is needed to calculate details of the sliced image.
188+
"image_sizes": [image.size for image in images], # list of image sizes
189189
}
190190
}
191191

vllm/model_executor/models/minicpmo.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
# limitations under the License.
2424
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
2525
from functools import partial
26-
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
27-
Tuple, TypedDict, Union)
26+
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
27+
Optional, Set, Tuple, TypedDict, Union)
2828

2929
import torch
3030
from torch import nn
@@ -122,13 +122,16 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
122122
def __init__(
123123
self,
124124
data: Mapping[str, torch.Tensor],
125-
fields_config: Mapping[str, MultiModalFieldConfig],
125+
fields_factory: Callable[
126+
[Mapping[str, torch.Tensor]],
127+
Mapping[str, MultiModalFieldConfig],
128+
],
126129
) -> None:
127130
super().__init__(
128131
data,
129132
modality="image",
130-
fields_config=fields_config,
131133
required_fields={"audio_embeds"},
134+
fields_factory=fields_factory,
132135
)
133136

134137

@@ -141,7 +144,7 @@ def _parse_audio_data(
141144
if isinstance(data, dict):
142145
return MiniCPMOAudioEmbeddingItems(
143146
data,
144-
fields_config=_minicpmo_field_config(data),
147+
fields_factory=_minicpmo_field_config,
145148
)
146149

147150
return super()._parse_audio_data(data)

vllm/model_executor/models/minicpmv.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,13 +255,16 @@ class MiniCPMVImageEmbeddingItems(DictEmbeddingItems):
255255
def __init__(
256256
self,
257257
data: Mapping[str, torch.Tensor],
258-
fields_config: Mapping[str, MultiModalFieldConfig],
258+
fields_factory: Callable[
259+
[Mapping[str, torch.Tensor]],
260+
Mapping[str, MultiModalFieldConfig],
261+
],
259262
) -> None:
260263
super().__init__(
261264
data,
262265
modality="image",
263-
fields_config=fields_config,
264266
required_fields={"image_embeds", "image_sizes"},
267+
fields_factory=fields_factory,
265268
)
266269

267270
def get_image_size(self, index: int) -> ImageSize:
@@ -274,13 +277,16 @@ class MiniCPMVVideoEmbeddingItems(DictEmbeddingItems):
274277
def __init__(
275278
self,
276279
data: Mapping[str, torch.Tensor],
277-
fields_config: Mapping[str, MultiModalFieldConfig],
280+
fields_factory: Callable[
281+
[Mapping[str, torch.Tensor]],
282+
Mapping[str, MultiModalFieldConfig],
283+
],
278284
) -> None:
279285
super().__init__(
280286
data,
281287
modality="video",
282-
fields_config=fields_config,
283288
required_fields={"video_embeds", "video_image_sizes"},
289+
fields_factory=fields_factory,
284290
)
285291

286292
def get_frame_size(self, index: int) -> ImageSize:
@@ -300,7 +306,7 @@ def _parse_image_data(
300306
if isinstance(data, dict):
301307
return MiniCPMVImageEmbeddingItems(
302308
data,
303-
fields_config=_minicpmv_field_config(data),
309+
fields_factory=_minicpmv_field_config,
304310
)
305311

306312
return super()._parse_image_data(data)
@@ -312,7 +318,7 @@ def _parse_video_data(
312318
if isinstance(data, dict):
313319
return MiniCPMVVideoEmbeddingItems(
314320
data,
315-
fields_config=_minicpmv_field_config(data),
321+
fields_factory=_minicpmv_field_config,
316322
)
317323

318324
return super()._parse_video_data(data)

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,8 @@ def _parse_image_data(
691691
return DictEmbeddingItems(
692692
data,
693693
modality="image",
694-
fields_config=_qwen2vl_field_config(data),
695694
required_fields={"image_embeds", "image_grid_thw"},
695+
fields_factory=_qwen2vl_field_config,
696696
)
697697

698698
return super()._parse_image_data(data)
@@ -705,8 +705,8 @@ def _parse_video_data(
705705
return DictEmbeddingItems(
706706
data,
707707
modality="video",
708-
fields_config=_qwen2vl_field_config(data),
709708
required_fields={"video_embeds", "video_grid_thw"},
709+
fields_factory=_qwen2vl_field_config,
710710
)
711711

712712
return super()._parse_video_data(data)

vllm/multimodal/parse.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,28 @@ def __init__(
125125
self,
126126
data: Mapping[str, torch.Tensor],
127127
modality: str,
128-
fields_config: Mapping[str, MultiModalFieldConfig],
129128
required_fields: set[str],
129+
fields_factory: Callable[
130+
[Mapping[str, torch.Tensor]],
131+
Mapping[str, MultiModalFieldConfig],
132+
],
130133
) -> None:
131134
super().__init__(data, modality)
132135

133-
missing_required_fields = required_fields - fields_config.keys()
134-
if missing_required_fields:
135-
fields = set(fields_config.keys())
136-
msg = f"{required_fields=} should be a subset of {fields=}"
137-
raise ValueError(msg)
138-
139136
missing_required_data_keys = required_fields - data.keys()
140137
if missing_required_data_keys:
141138
data_keys = set(data.keys())
142139
msg = (f"The data should contain the fields: {required_fields}, "
143140
f"but only found the following keys: {data_keys}")
144141
raise ValueError(msg)
145142

143+
fields_config = fields_factory(data)
144+
missing_required_fields = required_fields - fields_config.keys()
145+
if missing_required_fields:
146+
fields = set(fields_config.keys())
147+
msg = f"{required_fields=} should be a subset of {fields=}"
148+
raise ValueError(msg)
149+
146150
self.fields_config = fields_config
147151
self.required_fields = required_fields
148152

0 commit comments

Comments
 (0)