Skip to content

Commit 74db351

Browse files
bbeckcaepwalsh
authored andcommitted
Migrate Pixtral inputs to TensorSchema (vllm-project#23472)
Signed-off-by: Benji Beck <benjibeck@meta.com>
1 parent 2674b66 commit 74db351

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

vllm/model_executor/models/pixtral.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Iterable, Mapping, Sequence
66
from dataclasses import dataclass, fields
77
from functools import cached_property
8-
from typing import Literal, Optional, TypedDict, Union
8+
from typing import Annotated, Literal, Optional, Union
99

1010
import torch
1111
import torch.nn as nn
@@ -48,6 +48,7 @@
4848
from vllm.sequence import IntermediateTensors
4949
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
5050
cached_tokenizer_from_config)
51+
from vllm.utils.tensor_schema import TensorSchema, TensorShape
5152

5253
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
5354
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
@@ -68,15 +69,20 @@
6869
PATCH_MERGE = "patch_merge"
6970

7071

71-
class PixtralImagePixelInputs(TypedDict):
72-
type: Literal["pixel_values"]
73-
74-
images: Union[torch.Tensor, list[torch.Tensor]]
72+
class PixtralImagePixelInputs(TensorSchema):
7573
"""
76-
Shape: `(batch_size * num_images, num_channels, image_width, image_height)`
77-
74+
Dimensions:
75+
- bn: Batch size * number of images
76+
- c: Number of channels (3)
77+
- h: Height of each image
78+
- w: Width of each image
79+
7880
The result of stacking `ImageEncoding.tokens` from each prompt.
7981
"""
82+
type: Literal["pixel_values"] = "pixel_values"
83+
84+
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
85+
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})]
8086

8187

8288
class PixtralProcessorAdapter:
@@ -381,10 +387,6 @@ def _parse_and_validate_image_input(
381387
if images is None:
382388
return None
383389

384-
if not isinstance(images, (torch.Tensor, list)):
385-
raise ValueError("Incorrect type of images. "
386-
f"Got type: {type(images)}")
387-
388390
return PixtralImagePixelInputs(
389391
type="pixel_values",
390392
images=flatten_bn(images),

0 commit comments

Comments
 (0)