
Commit ab93f13

[VLM] Various cleanup and fixes (#14806)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 40253ba commit ab93f13

14 files changed, +283 -273 lines changed

vllm/entrypoints/chat_utils.py

Lines changed: 14 additions & 1 deletion
@@ -37,6 +37,7 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
+from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

 logger = init_logger(__name__)
@@ -1070,7 +1071,19 @@ def apply_hf_chat_template(
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
-    if chat_template is None and tokenizer.chat_template is None:
+    if chat_template is None:
+        chat_template = tokenizer.chat_template
+
+    # FIXME: Temporary workaround for
+    # https://huggingface.co/mistral-community/pixtral-12b/discussions/31
+    if chat_template is None:
+        try:
+            processor = cached_get_processor(tokenizer.name_or_path)
+            chat_template = processor.chat_template
+        except Exception:
+            pass
+
+    if chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "

vllm/model_executor/models/fuyu.py

Lines changed: 8 additions & 7 deletions
@@ -18,7 +18,7 @@
 """ PyTorch Fuyu model."""
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import List, Literal, Optional, Set, Tuple, TypedDict
+from typing import Literal, Optional, Set, Tuple, TypedDict

 import torch
 import torch.nn as nn
@@ -31,8 +31,7 @@
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -58,10 +57,12 @@ class FuyuImagePatchInputs(TypedDict):
     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
     """

-    patches_per_image: List[int]
+    patches_per_image: list[int]
     """
-    List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `flat_data`.
+    The number of total patches for each image in the batch.
+
+    This is used to split the embeddings which has the first two dimensions
+    flattened just like `flat_data`.
     """

@@ -317,7 +318,7 @@ def _parse_and_validate_image_input(
         return None

     def _process_image_input(
-            self, image_input: FuyuImagePatchInputs) -> NestedTensors:
+            self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
         image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]
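The updated docstring says `patches_per_image` is used to split embeddings whose first two dimensions are flattened, mirroring `flat_data`. A tiny illustration with made-up shapes (not the vLLM implementation) of splitting such a flat tensor back into per-image chunks:

# Made-up shapes: split a flat patch tensor back per image.
import torch

flat_data = torch.randn(7, 16)   # (total_patches, patch_dim) for the whole batch
patches_per_image = [3, 4]       # patches contributed by each of the two images

per_image = torch.split(flat_data, patches_per_image, dim=0)
assert [chunk.shape[0] for chunk in per_image] == patches_per_image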

vllm/model_executor/models/interfaces.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@

 import torch
 from torch import Tensor
-from typing_extensions import TypeIs
+from typing_extensions import Self, TypeIs

 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
@@ -451,7 +451,7 @@ class SupportsQuant:
     packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
     quant_config: Optional[QuantizationConfig] = None

-    def __new__(cls, *args, **kwargs) -> "SupportsQuant":
+    def __new__(cls, *args, **kwargs) -> Self:
         instance = super().__new__(cls)
         quant_config = cls._find_quant_config(*args, **kwargs)
         if quant_config is not None:
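Returning `Self` instead of the hard-coded `"SupportsQuant"` lets type checkers infer the concrete subclass type from `__new__` in classes that mix in `SupportsQuant`. A minimal sketch of the typing effect, using an assumed Base/Child pair rather than the real classes:

# Minimal sketch of the typing change (assumed classes, not vLLM code).
from typing_extensions import Self

class Base:
    def __new__(cls, *args, **kwargs) -> Self:
        # With `Self`, constructing a subclass is typed as that subclass,
        # not as `Base`.
        return super().__new__(cls)

class Child(Base):
    pass

child = Child()
print(type(child).__name__)  # Child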

vllm/model_executor/models/llava.py

Lines changed: 64 additions & 72 deletions
@@ -3,8 +3,8 @@
 from abc import abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
-from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
-                    TypedDict, TypeVar, Union, cast)
+from typing import (Final, Literal, Optional, Protocol, Set, Tuple, TypedDict,
+                    TypeVar, Union, cast)

 import torch
 import torch.nn as nn
@@ -39,8 +39,7 @@

 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
@@ -49,31 +48,42 @@

 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: torch.Tensor
     """
     Shape: `(batch_size * num_images, num_channels, height, width)`

     Note that `height` or `width` may be different per batch and image,
     in which case the data is passed as a list instead of a batched tensor.
     """

-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+
+class PixtralHFImagePixelInputs(TypedDict):
+    type: Literal["pixel_values_pixtral"]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
+
+    feat_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image features correspond
     to patch tokens.

     Shape: `(batch_size, num_crops, num_patch)`
     """

-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
     A boolean mask indicating which image embeddings correspond
     to patch tokens.

     Shape: `(batch_size, num_embeds)`
     """

-    num_crops: torch.Tensor
+    num_crops: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""

@@ -85,27 +95,9 @@ class LlavaImageEmbeddingInputs(TypedDict):
     `hidden_size` must match the hidden size of language model backbone.
     """

-    feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image features correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_crops, num_patch)`
-    """
-
-    embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size, num_embeds)`
-    """
-
-    num_crops: torch.Tensor
-    """Shape: `(batch_size, num_images)`"""
-

-LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
+LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
+                         LlavaImageEmbeddingInputs]


 class LlavaMultiModalProjector(nn.Module):
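With `PixtralHFImagePixelInputs` split out, `LlavaImageInputs` becomes a union of TypedDicts discriminated by their `type` Literal tag, so downstream code can branch on `image_input["type"]` instead of probing optional keys. A self-contained sketch of that pattern using simplified stand-in classes, not the real input types:

# Sketch of a tagged union of TypedDicts (assumed, simplified classes).
from typing import Literal, TypedDict, Union

class PixelInputs(TypedDict):
    type: Literal["pixel_values"]
    pixel_values: list[float]

class EmbedInputs(TypedDict):
    type: Literal["image_embeds"]
    data: list[float]

ImageInputs = Union[PixelInputs, EmbedInputs]

def process(inputs: ImageInputs) -> list[float]:
    if inputs["type"] == "image_embeds":
        return inputs["data"]        # narrowed to EmbedInputs
    return inputs["pixel_values"]    # narrowed to PixelInputs

print(process({"type": "image_embeds", "data": [0.1, 0.2]}))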
@@ -357,13 +349,15 @@ def _call_hf_processor(
             ]

         hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)

         tile_sizes = [
-            get_pixtral_hf_image_feature_grid_size(
-                hf_config.vision_config,
+            encoder_info.get_patch_grid_size(
                 image_width=pixel_value.shape[-1],
-                image_height=pixel_value.shape[-2])
-            for pixel_value in processed_outputs["pixel_values"]
+                image_height=pixel_value.shape[-2],
+            ) for pixel_value in processed_outputs["pixel_values"]
         ]
         num_crops = torch.tensor([(ncols + 1) * nrows
                                   for ncols, nrows in tile_sizes])
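The grid returned by `get_patch_grid_size` feeds the `num_crops` computation: each image contributes `(ncols + 1) * nrows` entries, i.e. one extra slot per row on top of its `ncols` patches (presumably the per-row break token). A worked example of the arithmetic with made-up grid sizes:

# Worked example with made-up grid sizes (not actual model outputs).
import torch

tile_sizes = [(4, 3), (2, 2)]  # (ncols, nrows) per image
num_crops = torch.tensor([(ncols + 1) * nrows for ncols, nrows in tile_sizes])
print(num_crops)  # tensor([15,  6])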
@@ -411,13 +405,13 @@ def _get_prompt_updates(

         vision_config = hf_config.vision_config
         assert isinstance(vision_config, PixtralVisionConfig)
+        encoder_info = PixtralHFEncoderInfo(vision_config)

         def get_replacement(item_idx: int):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)

-            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
-                vision_config,
+            ncols, nrows = encoder_info.get_patch_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
             )
@@ -512,7 +506,7 @@ def init_vision_tower_for_llava(
     *,
     require_post_norm: Optional[bool] = None,
     prefix: str = "",
-):
+) -> Union[CLIPVisionModel, SiglipVisionModel, PixtralHFVisionModel]:
     vision_config = hf_config.vision_config

     # Initialize the vision tower only up to the deepest required feature layer
@@ -627,57 +621,52 @@ def _parse_and_validate_image_input(
         if pixel_values is None and image_embeds is None:
             return None

-        feat_is_patch = kwargs.pop("feat_is_patch", None)
-        if feat_is_patch is not None and not isinstance(
-                feat_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of feat_is_patch. "
-                             f"Got type: {type(feat_is_patch)}")
-
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
-        if embed_is_patch is not None and not isinstance(
-                embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
-        num_crops = kwargs.pop("num_crops", None)
-        if num_crops is not None and not isinstance(num_crops, torch.Tensor):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         if pixel_values is not None:
             if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")

             if self.config.vision_config.model_type == "pixtral":
-                return LlavaImagePixelInputs(
-                    type="pixel_values",
-                    data=flatten_bn(pixel_values),
+                feat_is_patch = kwargs.pop("feat_is_patch")
+                if not isinstance(feat_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of feat_is_patch. "
+                                     f"Got type: {type(feat_is_patch)}")
+
+                embed_is_patch = kwargs.pop("embed_is_patch")
+                if not isinstance(embed_is_patch, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of embed_is_patch. "
+                                     f"Got type: {type(embed_is_patch)}")
+
+                num_crops = kwargs.pop("num_crops")
+                if not isinstance(num_crops, (torch.Tensor, list)):
+                    raise ValueError("Incorrect type of num_crops. "
+                                     f"Got type: {type(num_crops)}")
+
+                return PixtralHFImagePixelInputs(
+                    type="pixel_values_pixtral",
+                    pixel_values=flatten_bn(pixel_values),
                     feat_is_patch=feat_is_patch,
                     embed_is_patch=embed_is_patch,
                     num_crops=num_crops,
                 )

             return LlavaImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(
+                pixel_values=self._validate_pixel_values(
                     flatten_bn(pixel_values, concat=True)),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )

         if image_embeds is not None:
             if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")

+            if self.config.vision_config.model_type == "pixtral":
+                raise ValueError("Pixtral-HF does not support image_embeds.")
+
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
                 data=flatten_bn(image_embeds, concat=True),
-                feat_is_patch=feat_is_patch,
-                embed_is_patch=embed_is_patch,
-                num_crops=num_crops,
             )

         raise AssertionError("This line should be unreachable.")
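Validation of `feat_is_patch`, `embed_is_patch`, and `num_crops` now happens only on the Pixtral branch, where those kwargs are required rather than optional. A generic sketch of the pop-and-type-check pattern used there, with a hypothetical helper name:

# Generic sketch (hypothetical helper, not the vLLM implementation).
import torch

def pop_required_tensor_or_list(kwargs: dict, name: str):
    value = kwargs.pop(name)  # required: raises KeyError if missing
    if not isinstance(value, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of {name}. Got type: {type(value)}")
    return value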
@@ -696,7 +685,7 @@ def _image_pixels_to_features(
         self,
         vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                             PixtralHFVisionModel],
-        pixel_values: torch.Tensor,
+        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
     ) -> torch.Tensor:

         # NOTE: we skip the step to select the vision feature layer since
@@ -708,17 +697,20 @@ def _image_pixels_to_features(
             strategy=self.config.vision_feature_select_strategy,
         )

-    def _process_image_pixels(self,
-                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
+    def _process_image_pixels(
+        self,
+        inputs: Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs],
+    ) -> torch.Tensor:
         assert self.vision_tower is not None

-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]

         return self._image_pixels_to_features(self.vision_tower, pixel_values)

-    def _process_image_input(self,
-                             image_input: LlavaImageInputs) -> torch.Tensor:
-
+    def _process_image_input(
+        self,
+        image_input: LlavaImageInputs,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
@@ -783,11 +775,11 @@ def get_multimodal_embeddings(
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
+
         vision_embeddings = self._process_image_input(image_input)

-        if kwargs.get("v0_path", False) or \
-            image_input.get("feat_is_patch") is None or \
-            image_input.get("embed_is_patch") is None:
+        if (kwargs.get("v0_path", False)
+                or image_input["type"] != "pixel_values_pixtral"):
             # The path is used for pixtral (V0 only) and llava (V0/V1)
             return vision_embeddings
vllm/model_executor/models/llava_next.py

Lines changed: 4 additions & 3 deletions
@@ -32,7 +32,7 @@

 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: Union[torch.Tensor, List[torch.Tensor]]
+    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
     """
     Shape:
     `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
@@ -315,7 +315,8 @@ def _parse_and_validate_image_input(

             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                data=self._validate_pixel_values(flatten_bn(pixel_values)),
+                pixel_values=self._validate_pixel_values(
+                    flatten_bn(pixel_values)),
                 image_sizes=self._validate_image_sizes(
                     flatten_bn(image_sizes, concat=True)),
             )
@@ -434,7 +435,7 @@ def _process_image_pixels(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
         assert self.vision_tower is not None

-        pixel_values = inputs["data"]
+        pixel_values = inputs["pixel_values"]

         if isinstance(pixel_values, torch.Tensor):
             b, num_patches, c, h, w = pixel_values.shape
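Here `_process_image_pixels` reads `inputs["pixel_values"]`, which for LLaVA-NeXT is a 5-D batched tensor (or a list of such tensors). A tiny illustration with made-up sizes, not the vLLM code, of unpacking and flattening that layout:

# Made-up sizes: the batched 5-D layout unpacked above.
import torch

pixel_values = torch.randn(2, 5, 3, 336, 336)       # (b, patches, c, h, w)
b, num_patches, c, h, w = pixel_values.shape
flat = pixel_values.view(b * num_patches, c, h, w)  # merge batch and patch dims
assert flat.shape == (10, 3, 336, 336)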
