
Commit 958432a

davidxia authored and aarnphm committed
[Frontend] decrease import time of vllm.multimodal (vllm-project#18031)
Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
1 parent 73da97d commit 958432a

File tree: 3 files changed (+45, -34 lines)
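The change works by pushing the heavy imports (torch, PIL, transformers) behind `TYPE_CHECKING` guards, quoted annotations, and a lazy module proxy, so that `import vllm.multimodal` no longer pays for them up front. A minimal, self-contained sketch of that general pattern (the names below are illustrative, not vLLM's code):

```python
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Evaluated by type checkers only; never executed at runtime.
    import torch
    from PIL.Image import Image

# Quoted ("forward") annotations keep importing this module cheap:
# torch and PIL are not pulled in just to define the alias.
ImageLike = Union["Image", "torch.Tensor"]


def describe(item: ImageLike) -> str:
    # The real import happens only when the function is actually called.
    import torch
    if isinstance(item, torch.Tensor):
        return f"tensor of shape {tuple(item.shape)}"
    return f"PIL image of size {item.size}"
```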

vllm/multimodal/inputs.py

Lines changed: 24 additions & 21 deletions
@@ -10,40 +10,43 @@
                     Union, cast, final)
 
 import numpy as np
-import torch
-import torch.types
-from PIL.Image import Image
-from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias
 
 from vllm.jsontree import JSONTree, json_map_leaves
-from vllm.utils import full_groupby, is_list_of
+from vllm.utils import LazyLoader, full_groupby, is_list_of
 
 if TYPE_CHECKING:
+    import torch
+    import torch.types
+    from PIL.Image import Image
+    from transformers.feature_extraction_utils import BatchFeature
+
     from .hasher import MultiModalHashDict
+else:
+    torch = LazyLoader("torch", globals(), "torch")
 
 _T = TypeVar("_T")
 
-HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
+HfImageItem: TypeAlias = Union["Image", np.ndarray, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace `ImageProcessor`.
 """
 
-HfVideoItem: TypeAlias = Union[list[Image], np.ndarray, torch.Tensor,
-                               list[np.ndarray], list[torch.Tensor]]
+HfVideoItem: TypeAlias = Union[list["Image"], np.ndarray, "torch.Tensor",
+                               list[np.ndarray], list["torch.Tensor"]]
 """
 A {class}`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace `VideoProcessor`.
 """
 
-HfAudioItem: TypeAlias = Union[list[float], np.ndarray, torch.Tensor]
+HfAudioItem: TypeAlias = Union[list[float], np.ndarray, "torch.Tensor"]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace `AudioProcessor`.
 """
 
-ImageItem: TypeAlias = Union[HfImageItem, torch.Tensor]
+ImageItem: TypeAlias = Union[HfImageItem, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.ImageInput` representing a single image
 item, which can be passed to a HuggingFace `ImageProcessor`.
@@ -53,7 +56,7 @@
 these are directly passed to the model without HF processing.
 """
 
-VideoItem: TypeAlias = Union[HfVideoItem, torch.Tensor]
+VideoItem: TypeAlias = Union[HfVideoItem, "torch.Tensor"]
 """
 A {class}`transformers.image_utils.VideoInput` representing a single video
 item, which can be passed to a HuggingFace `VideoProcessor`.
@@ -64,7 +67,7 @@
 """
 
 AudioItem: TypeAlias = Union[HfAudioItem, tuple[np.ndarray, float],
-                             torch.Tensor]
+                             "torch.Tensor"]
 """
 Represents a single audio
 item, which can be passed to a HuggingFace `AudioProcessor`.
@@ -132,7 +135,7 @@ class PlaceholderRange:
     length: int
     """The length of the placeholder."""
 
-    is_embed: Optional[torch.Tensor] = None
+    is_embed: Optional["torch.Tensor"] = None
     """
     A boolean mask of shape `(length,)` indicating which positions
     between `offset` and `offset + length` to assign embeddings to.
@@ -158,8 +161,8 @@ def __eq__(self, other: object) -> bool:
         return nested_tensors_equal(self.is_embed, other.is_embed)
 
 
-NestedTensors = Union[list["NestedTensors"], list[torch.Tensor], torch.Tensor,
-                      tuple[torch.Tensor, ...]]
+NestedTensors: TypeAlias = Union[list["NestedTensors"], list["torch.Tensor"],
+                                 "torch.Tensor", tuple["torch.Tensor", ...]]
 """
 Uses a list instead of a tensor if the dimensions of each element do not match.
 """
@@ -261,7 +264,7 @@ def build_elems(
         """
         Construct {class}`MultiModalFieldElem` instances to represent
        the provided data.
-        
+
         This is the inverse of {meth}`reduce_data`.
         """
         raise NotImplementedError
@@ -422,7 +425,7 @@ def flat(modality: str,
             modality: The modality of the multi-modal item that uses this
                 keyword argument.
             slices: For each multi-modal item, a slice (dim=0) or a tuple of
-                slices (dim>0) that is used to extract the data corresponding 
+                slices (dim>0) that is used to extract the data corresponding
                 to it.
             dim: The dimension to extract data, default to 0.
 
@@ -465,7 +468,7 @@ def flat(modality: str,
 
     @staticmethod
     def flat_from_sizes(modality: str,
-                        size_per_item: torch.Tensor,
+                        size_per_item: "torch.Tensor",
                         dim: int = 0):
         """
         Defines a field where an element in the batch is obtained by
@@ -602,7 +605,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
 
     @staticmethod
     def from_hf_inputs(
-        hf_inputs: BatchFeature,
+        hf_inputs: "BatchFeature",
         config_by_key: Mapping[str, MultiModalFieldConfig],
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
@@ -792,7 +795,7 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
         return self._items_by_modality[modality]
 
 
-MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
+MultiModalPlaceholderDict: TypeAlias = Mapping[str, Sequence[PlaceholderRange]]
 """
 A dictionary containing placeholder ranges for each modality.
 """
@@ -823,7 +826,7 @@ class MultiModalInputs(TypedDict):
     mm_hashes: Optional["MultiModalHashDict"]
     """The hashes of the multi-modal data."""
 
-    mm_placeholders: MultiModalPlaceholderDict
+    mm_placeholders: "MultiModalPlaceholderDict"
     """
     For each modality, information about the placeholder tokens in
     `prompt_token_ids`.
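The `else` branch binds `torch` to `LazyLoader("torch", globals(), "torch")` so existing runtime references like `torch.Tensor` keep working. The loader itself is not shown in this diff; lazy module proxies of this kind are commonly written as a `types.ModuleType` subclass that performs the real import on first attribute access. A rough sketch under that assumption, not vLLM's actual implementation:

```python
import importlib
import types


class LazyLoader(types.ModuleType):
    """Defer importing `name` until an attribute is first accessed."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        # Perform the real import and replace the proxy in the caller's
        # namespace so later lookups go straight to the module.
        module = importlib.import_module(self.__name__)
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)
```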

vllm/multimodal/parse.py

Lines changed: 13 additions & 8 deletions
@@ -8,11 +8,9 @@
 
 import numpy as np
 import torch
-from PIL.Image import Image
-from transformers import BatchFeature
 from typing_extensions import TypeAlias, TypeGuard, assert_never
 
-from vllm.utils import is_list_of
+from vllm.utils import LazyLoader, is_list_of
 
 from .audio import AudioResampler
 from .inputs import (AudioItem, HfAudioItem, HfImageItem, HfVideoItem,
@@ -22,6 +20,11 @@
 _T = TypeVar("_T")
 _I = TypeVar("_I")
 
+if TYPE_CHECKING:
+    import PIL.Image as PILImage
+else:
+    PILImage = LazyLoader("PILImage", globals(), "PIL.Image")
+
 
 class ModalityDataItems(ABC, Generic[_T, _I]):
     """
@@ -131,6 +134,8 @@ def __init__(
             Mapping[str, MultiModalFieldConfig],
         ],
     ) -> None:
+        from transformers.feature_extraction_utils import BatchFeature
+
         super().__init__(data, modality)
 
         missing_required_data_keys = required_fields - data.keys()
@@ -200,7 +205,7 @@ def __init__(self, data: Sequence[HfImageItem]) -> None:
     def get_image_size(self, item_idx: int) -> ImageSize:
         image = self.get(item_idx)
 
-        if isinstance(image, Image):
+        if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
             _, h, w = image.shape
@@ -226,7 +231,7 @@ def get_num_frames(self, item_idx: int) -> int:
     def get_frame_size(self, item_idx: int) -> ImageSize:
         image = self.get(item_idx)[0]  # Assume that the video isn't empty
 
-        if isinstance(image, Image):
+        if isinstance(image, PILImage.Image):
             return ImageSize(*image.size)
         if isinstance(image, (np.ndarray, torch.Tensor)):
             _, h, w = image.shape
@@ -253,7 +258,7 @@ class MultiModalDataItems(UserDict[str, ModalityDataItems[Any, Any]]):
     def get_count(self, modality: str, *, strict: bool = True) -> int:
         """
         Get the number of data items belonging to a modality.
-        
+
         If `strict=False`, return `0` instead of raising {exc}`KeyError`
         even if the modality is not found.
         """
@@ -399,7 +404,7 @@ def _parse_image_data(
         if self._is_embeddings(data):
             return ImageEmbeddingItems(data)
 
-        if (isinstance(data, Image)
+        if (isinstance(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 3):
             data_items = [data]
@@ -420,7 +425,7 @@ def _parse_video_data(
         if self._is_embeddings(data):
             return VideoEmbeddingItems(data)
 
-        if (is_list_of(data, Image)
+        if (is_list_of(data, PILImage.Image)
                 or isinstance(data,
                               (np.ndarray, torch.Tensor)) and data.ndim == 4):
             data_items = [data]
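Because the proxy resolves attributes on demand, checks such as `isinstance(image, PILImage.Image)` only trigger the real `import PIL.Image` when the parser is actually used, while importing the module stays cheap. A small usage sketch (the `LazyLoader` import path is taken from the diff above; `is_pil_image` is a hypothetical helper, not part of the codebase):

```python
from vllm.utils import LazyLoader

# No PIL import happens here; only a proxy module object is created.
PILImage = LazyLoader("PILImage", globals(), "PIL.Image")


def is_pil_image(obj) -> bool:
    # The first attribute access on the proxy performs the actual
    # `import PIL.Image`, so the isinstance check works at call time.
    return isinstance(obj, PILImage.Image)
```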

vllm/multimodal/processing.py

Lines changed: 8 additions & 5 deletions
@@ -13,7 +13,6 @@
                     TypeVar, Union, cast)
 
 import torch
-from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import assert_never
 
 from vllm.inputs import InputProcessingContext
@@ -31,6 +30,10 @@
                     MultiModalDataParser)
 
 if TYPE_CHECKING:
+    from transformers.configuration_utils import PretrainedConfig
+    from transformers.feature_extraction_utils import BatchFeature
+    from transformers.processing_utils import ProcessorMixin
+
     from .profiling import BaseDummyInputsBuilder
 
 logger = init_logger(__name__)
@@ -1047,10 +1050,10 @@ def model_id(self) -> str:
     def get_tokenizer(self) -> AnyTokenizer:
         return self.ctx.tokenizer
 
-    def get_hf_config(self) -> PretrainedConfig:
+    def get_hf_config(self) -> "PretrainedConfig":
         return self.ctx.get_hf_config()
 
-    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
+    def get_hf_processor(self, **kwargs: object) -> "ProcessorMixin":
         """
         Subclasses can override this method to handle
         specific kwargs from model config or user inputs.
@@ -1165,7 +1168,7 @@ def _to_mm_items(
     @abstractmethod
     def _get_mm_fields_config(
         self,
-        hf_inputs: BatchFeature,
+        hf_inputs: "BatchFeature",
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         """Given the HF-processed data, output the metadata of each field."""
@@ -1222,7 +1225,7 @@ def _call_hf_processor(
         # This refers to the data to be passed to HF processor.
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
+    ) -> "BatchFeature":
         """
         Call the HF processor on the prompt text and
         associated multi-modal data.
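One way to sanity-check a change like this is to time a cold import, or run `python -X importtime -c "import vllm.multimodal"` for a per-module breakdown. A rough wall-clock sketch (assumes vLLM is installed and a fresh interpreter):

```python
# Time a cold import of the package named in the commit title.
import importlib
import sys
import time

assert "vllm.multimodal" not in sys.modules, "run this in a fresh interpreter"

start = time.perf_counter()
importlib.import_module("vllm.multimodal")
print(f"import vllm.multimodal: {time.perf_counter() - start:.3f}s")
```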

0 commit comments
