Commit ecd99b8

DarkLight1337 authored and LeiWang1999 committed
[Core] Rename input data types (vllm-project#8688)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 4c3f2ac commit ecd99b8

32 files changed (+438 -340 lines)
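
In short, this commit renames the engine-facing input TypedDicts (LLMInputs becomes DecoderOnlyInputs, EncoderDecoderLLMInputs becomes EncoderDecoderInputs) and adds a token_inputs() constructor for the new TokenInputs type. A rough before/after sketch of a call site, using an illustrative prompt and token IDs that are not taken from the diff:

# Before this commit: the TypedDict was constructed directly.
from vllm.inputs import LLMInputs  # old name; still importable, with a DeprecationWarning

inputs = LLMInputs(prompt_token_ids=[1, 2, 3], prompt="Hello")

# After this commit: the token_inputs() helper builds the renamed type and
# only sets the optional keys that were actually provided.
from vllm.inputs import token_inputs

inputs = token_inputs(prompt_token_ids=[1, 2, 3], prompt="Hello")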

docs/source/dev/input_processing/model_inputs_index.rst

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ Module Contents
 LLM Engine Inputs
 -----------------
 
-.. autoclass:: vllm.inputs.LLMInputs
+.. autoclass:: vllm.inputs.DecoderOnlyInputs
     :members:
     :show-inheritance:

tests/models/decoder_only/vision_language/test_phi3v.py

Lines changed: 13 additions & 15 deletions
@@ -1,12 +1,12 @@
 import os
 import re
-from typing import Callable, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
 
 import pytest
 import torch
 from transformers import AutoImageProcessor, AutoTokenizer
 
-from vllm.inputs import InputContext, LLMInputs
+from vllm.inputs import InputContext, token_inputs
 from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.utils import rescale_image_size
@@ -311,7 +311,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets,
     (4, 781),
     (16, 2653),
 ])
-def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
+def test_max_tokens_override(get_max_phi3v_image_tokens, model: str,
                              num_crops: int, expected_max_tokens: int):
     """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
     # NOTE: mm_processor_kwargs on the context in this test is unused, since
@@ -343,8 +343,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
     (16, 2653, 1),
     (16, 2653, 2),
 ])
-def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
-                             num_crops: int, toks_per_img: int, num_imgs: int):
+def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int,
+                             toks_per_img: int, num_imgs: int):
     """Ensure dummy_data_for_phi3v handles num_crops properly."""
     # Same as the previous test - don't initialize mm_processor_kwargs
     # in this test and assume that the kwargs will be correctly expanded by
@@ -374,7 +374,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
     (16, 1921, 1),
     (16, 1921, 2),
 ])
-def test_input_processor_override(input_processor_for_phi3v: Callable,
+def test_input_processor_override(input_processor_for_phi3v,
                                   image_assets: _ImageAssets, model: str,
                                   num_crops: int, expected_toks_per_img: int,
                                   num_imgs: int):
@@ -393,16 +393,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable,
     prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
     images = [image_assets[0].pil_image] * num_imgs
 
-    llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
-                           prompt=prompt,
-                           multi_modal_data={"image": images})
+    inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
+                          prompt=prompt,
+                          multi_modal_data={"image": images})
 
-    proc_llm_inputs = input_processor_for_phi3v(
-        ctx=ctx,
-        llm_inputs=llm_inputs,
-        num_crops=num_crops,
-    )
+    processed_inputs = input_processor_for_phi3v(ctx,
+                                                 inputs,
+                                                 num_crops=num_crops)
 
     # Ensure we have the right number of placeholders per num_crops size
-    img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
+    img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
     assert img_tok_count == expected_toks_per_img * num_imgs

tests/models/decoder_only/vision_language/test_qwen.py

Lines changed: 6 additions & 6 deletions
@@ -5,7 +5,7 @@
 import torch
 from PIL.Image import Image
 
-from vllm.inputs import InputContext, LLMInputs
+from vllm.inputs import InputContext, token_inputs
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size
 
@@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen,
     """Happy cases for image inputs to Qwen's multimodal input processor."""
     prompt = "".join(
         [f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
-    inputs = LLMInputs(
+    inputs = token_inputs(
         prompt=prompt,
         # When processing multimodal data for a multimodal model, the qwen
         # input processor will overwrite the provided prompt_token_ids with
         # the image prompts
-        prompt_token_ids=None,
+        prompt_token_ids=[],
        multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
     )
     proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
@@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen,
                                      trust_remote_code=True)
     prompt = "Picture 1: <img></img>\n"
     prompt_token_ids = tokenizer.encode(prompt)
-    inputs = LLMInputs(prompt=prompt,
-                       prompt_token_ids=prompt_token_ids,
-                       multi_modal_data=mm_data)
+    inputs = token_inputs(prompt=prompt,
+                          prompt_token_ids=prompt_token_ids,
+                          multi_modal_data=mm_data)
     # Should fail since we have too many or too few dimensions for embeddings
     with pytest.raises(ValueError):
         input_processor_for_qwen(qwen_vl_context, inputs)
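
One behavioral detail in this file: the valid-data test now passes prompt_token_ids=[] instead of None, because token_inputs() types that argument as List[int]. A minimal sketch of the intent, assuming this commit is installed (the prompt string is illustrative):

from vllm.inputs import token_inputs

# The Qwen input processor is expected to overwrite these token IDs with the
# image prompts, so an empty list stands in as the placeholder rather than None.
inputs = token_inputs(prompt_token_ids=[], prompt="Picture 1: <img></img>\n")
assert inputs["prompt_token_ids"] == []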

tests/multimodal/test_processor_kwargs.py

Lines changed: 9 additions & 9 deletions
@@ -5,7 +5,7 @@
 import pytest
 import torch
 
-from vllm.inputs import InputContext, LLMInputs
+from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
 from vllm.inputs.registry import InputRegistry
 from vllm.multimodal import MultiModalRegistry
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
@@ -31,7 +31,7 @@ def use_processor_mock():
     """Patches the internal model input processor with an override callable."""
 
     def custom_processor(ctx: InputContext,
-                         llm_inputs: LLMInputs,
+                         inputs: DecoderOnlyInputs,
                          *,
                          num_crops=DEFAULT_NUM_CROPS):
         # For testing purposes, we don't worry about the llm inputs / return
@@ -84,7 +84,7 @@ def test_default_processor_is_a_noop():
     dummy_registry = InputRegistry()
     ctx = build_model_context(DUMMY_MODEL_ID)
     processor = dummy_registry.create_input_processor(ctx.model_config)
-    proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
+    proc_inputs = token_inputs(prompt_token_ids=[], prompt="")
     proc_outputs = processor(inputs=proc_inputs)
     assert proc_inputs is proc_outputs
 
@@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops,
     ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs)
     processor = dummy_registry.create_input_processor(ctx.model_config)
     num_crops_val = processor(
-        LLMInputs(prompt_token_ids=[],
-                  prompt="",
-                  mm_processor_kwargs=inference_kwargs))
+        token_inputs(prompt_token_ids=[],
+                     prompt="",
+                     mm_processor_kwargs=inference_kwargs))
     assert num_crops_val == expected_seq_count
 
 
@@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock,
     processor = dummy_registry.create_input_processor(ctx.model_config)
     # Should filter out the inference time kwargs
     num_crops_val = processor(
-        LLMInputs(prompt_token_ids=[],
-                  prompt="",
-                  mm_processor_kwargs=mm_processor_kwargs))
+        token_inputs(prompt_token_ids=[],
+                     prompt="",
+                     mm_processor_kwargs=mm_processor_kwargs))
     assert num_crops_val == DEFAULT_NUM_CROPS

vllm/engine/llm_engine.py

Lines changed: 5 additions & 5 deletions
@@ -29,8 +29,8 @@
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-                         InputRegistry, LLMInputs, PromptType)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
+                         EncoderDecoderInputs, InputRegistry, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -635,7 +635,7 @@ def _verify_args(self) -> None:
     def _add_processed_request(
         self,
         request_id: str,
-        processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
+        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
@@ -1855,8 +1855,8 @@ def is_encoder_decoder_model(self):
     def is_embedding_model(self):
         return self.model_config.is_embedding_model
 
-    def _validate_model_inputs(self, inputs: Union[LLMInputs,
-                                                   EncoderDecoderLLMInputs]):
+    def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
+                                                   EncoderDecoderInputs]):
         if self.model_config.is_multimodal_model:
             # For encoder-decoder multimodal models, the max_prompt_len
             # restricts the decoder prompt length

vllm/inputs/__init__.py

Lines changed: 29 additions & 8 deletions
@@ -1,7 +1,8 @@
-from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt, build_explicit_enc_dec_prompt,
-                   to_enc_dec_tuple_list, zip_enc_dec_prompts)
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
+                   ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
+                   SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
+                   build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
+                   token_inputs, zip_enc_dec_prompts)
 from .registry import InputContext, InputRegistry
 
 INPUT_REGISTRY = InputRegistry()
@@ -19,8 +20,11 @@
     "PromptType",
     "SingletonPrompt",
     "ExplicitEncoderDecoderPrompt",
-    "LLMInputs",
-    "EncoderDecoderLLMInputs",
+    "TokenInputs",
+    "token_inputs",
+    "SingletonInputs",
+    "DecoderOnlyInputs",
+    "EncoderDecoderInputs",
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
@@ -31,14 +35,31 @@
 
 
 def __getattr__(name: str):
-    if name == "PromptInput":
-        import warnings
+    import warnings
 
+    if name == "PromptInput":
         msg = ("PromptInput has been renamed to PromptType. "
                "The original name will be removed in an upcoming version.")
 
         warnings.warn(DeprecationWarning(msg), stacklevel=2)
 
         return PromptType
 
+    if name == "LLMInputs":
+        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return DecoderOnlyInputs
+
+    if name == "EncoderDecoderLLMInputs":
+        msg = (
+            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
+            "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return EncoderDecoderInputs
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
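
The module-level __getattr__ keeps the old names importable for one deprecation cycle. A small sketch of what a downstream caller would observe with this commit installed:

import warnings

import vllm.inputs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The legacy attribute is resolved through the __getattr__ hook above ...
    legacy = vllm.inputs.LLMInputs

# ... and it is simply the renamed type, with a DeprecationWarning emitted.
assert legacy is vllm.inputs.DecoderOnlyInputs
assert any(issubclass(w.category, DeprecationWarning) for w in caught)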

vllm/inputs/data.py

Lines changed: 62 additions & 16 deletions
@@ -1,5 +1,5 @@
 from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
-                    Optional, Tuple, Union)
+                    Optional, Tuple, Union, cast)
 
 from typing_extensions import NotRequired, TypedDict, TypeVar
 
@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):
 
 SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
 """
-Set of possible schemas for a single LLM input:
+Set of possible schemas for a single prompt:
 
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
@@ -120,13 +120,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
 """
 
 
-class LLMInputs(TypedDict):
-    """
-    The inputs in :class:`~vllm.LLMEngine` before they are
-    passed to the model executor.
-
-    This specifies the data required for decoder-only models.
-    """
+class TokenInputs(TypedDict):
+    """Represents token-based inputs."""
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""
 
@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
     """
 
 
-class EncoderDecoderLLMInputs(LLMInputs):
+def token_inputs(
+    prompt_token_ids: List[int],
+    prompt: Optional[str] = None,
+    multi_modal_data: Optional["MultiModalDataDict"] = None,
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+) -> TokenInputs:
+    """Construct :class:`TokenInputs` from optional values."""
+    inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
+
+    if prompt is not None:
+        inputs["prompt"] = prompt
+    if multi_modal_data is not None:
+        inputs["multi_modal_data"] = multi_modal_data
+    if mm_processor_kwargs is not None:
+        inputs["mm_processor_kwargs"] = mm_processor_kwargs
+
+    return inputs
+
+
+SingletonInputs = TokenInputs
+"""
+A processed :class:`SingletonPrompt` which can be passed to
+:class:`vllm.sequence.Sequence`.
+"""
+
+DecoderOnlyInputs = TokenInputs
+"""
+The inputs in :class:`~vllm.LLMEngine` before they are
+passed to the model executor.
+This specifies the data required for decoder-only models.
+"""
+
+
+class EncoderDecoderInputs(TokenInputs):
     """
     The inputs in :class:`~vllm.LLMEngine` before they are
     passed to the model executor.
@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
     be zipped with the encoder/decoder prompts.
     """
     if mm_processor_kwargs is None:
-        mm_processor_kwargs = {}
-    if isinstance(mm_processor_kwargs, Dict):
+        mm_processor_kwargs = cast(Dict[str, Any], {})
+    if isinstance(mm_processor_kwargs, dict):
         return [
-            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
-                                          mm_processor_kwargs)
+            build_explicit_enc_dec_prompt(
+                encoder_prompt, decoder_prompt,
+                cast(Dict[str, Any], mm_processor_kwargs))
             for (encoder_prompt,
                  decoder_prompt) in zip(enc_prompts, dec_prompts)
         ]
@@ -229,14 +258,31 @@ def to_enc_dec_tuple_list(
 
 
 def __getattr__(name: str):
-    if name == "PromptInput":
-        import warnings
+    import warnings
 
+    if name == "PromptInput":
         msg = ("PromptInput has been renamed to PromptType. "
                "The original name will be removed in an upcoming version.")
 
         warnings.warn(DeprecationWarning(msg), stacklevel=2)
 
         return PromptType
 
+    if name == "LLMInputs":
+        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return DecoderOnlyInputs
+
+    if name == "EncoderDecoderLLMInputs":
+        msg = (
+            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
+            "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return EncoderDecoderInputs
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
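
Because the optional TokenInputs fields remain NotRequired (as they were on LLMInputs), token_inputs() leaves out any field that is not passed rather than storing None. A short usage sketch under that assumption; the token IDs are illustrative:

from vllm.inputs import token_inputs

# Only prompt_token_ids is mandatory; unset optional fields are absent from
# the resulting dict entirely.
minimal = token_inputs(prompt_token_ids=[101, 2023, 102])
assert "prompt" not in minimal
assert "multi_modal_data" not in minimal

full = token_inputs(prompt_token_ids=[101, 2023, 102], prompt="a prompt")
assert full["prompt"] == "a prompt"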
