|
5 | 5 | from collections.abc import Iterable, Mapping, Sequence |
6 | 6 | from dataclasses import dataclass, fields |
7 | 7 | from functools import cached_property |
8 | | -from typing import Literal, Optional, TypedDict, Union |
| 8 | +from typing import Annotated, Literal, Optional, Union |
9 | 9 |
|
10 | 10 | import torch |
11 | 11 | import torch.nn as nn |
|
48 | 48 | from vllm.sequence import IntermediateTensors |
49 | 49 | from vllm.transformers_utils.tokenizer import (MistralTokenizer, |
50 | 50 | cached_tokenizer_from_config) |
| 51 | +from vllm.utils.tensor_schema import TensorSchema, TensorShape |
51 | 52 |
|
52 | 53 | from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP |
53 | 54 | from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, |
|
68 | 69 | PATCH_MERGE = "patch_merge" |
69 | 70 |
|
70 | 71 |
|
71 | | -class PixtralImagePixelInputs(TypedDict): |
72 | | - type: Literal["pixel_values"] |
73 | | - |
74 | | - images: Union[torch.Tensor, list[torch.Tensor]] |
| 72 | +class PixtralImagePixelInputs(TensorSchema): |
75 | 73 | """ |
76 | | - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` |
77 | | -
|
| 74 | + Dimensions: |
| 75 | + - bn: Batch size * number of images |
| 76 | + - c: Number of channels (3) |
| 77 | + - h: Height of each image |
| 78 | + - w: Width of each image |
| 79 | + |
78 | 80 | The result of stacking `ImageEncoding.tokens` from each prompt. |
79 | 81 | """ |
| 82 | + type: Literal["pixel_values"] = "pixel_values" |
| 83 | + |
| 84 | + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], |
| 85 | + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] |
80 | 86 |
|
81 | 87 |
|
82 | 88 | class PixtralProcessorAdapter: |
@@ -381,10 +387,6 @@ def _parse_and_validate_image_input( |
381 | 387 | if images is None: |
382 | 388 | return None |
383 | 389 |
|
384 | | - if not isinstance(images, (torch.Tensor, list)): |
385 | | - raise ValueError("Incorrect type of images. " |
386 | | - f"Got type: {type(images)}") |
387 | | - |
388 | 390 | return PixtralImagePixelInputs( |
389 | 391 | type="pixel_values", |
390 | 392 | images=flatten_bn(images), |
|
0 commit comments