
Commit a13a962

petersalas authored and richardsliu committed

[Core][VLM] Add precise multi-modal placeholder tracking (vllm-project#8346)

Signed-off-by: Peter Salas <peter@fixie.ai>
Signed-off-by: Richard Liu <ricliu@google.com>

1 parent 7655274 commit a13a962


53 files changed (+914, -282 lines)
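
The "precise tracking" in the title refers to recording, for every multi-modal item in a prompt, exactly which placeholder token positions that item occupies; this appears to be what lets chunked prefill work with audio/image inputs, since the example and tests below stop disabling it. The ranges surface in the diffs below as offset/length pairs. A minimal sketch of the shape of that data, with plain dicts standing in for vLLM's placeholder-range type (whose definition is not part of the hunks shown here):

# Sketch: "Image:<image>Image:<image>!" with the two images padded to 3 and 2
# placeholder tokens ends up with placeholders at positions 2-4 and 7-8
# (cf. tests/multimodal/test_utils.py below). The new tracking records one
# offset/length range per multi-modal item, keyed by modality.
placeholder_ranges = {
    "image": [
        {"offset": 2, "length": 3},
        {"offset": 7, "length": 2},
    ],
}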

examples/offline_inference_audio_language.py

Lines changed: 1 addition & 5 deletions
@@ -34,11 +34,7 @@ def run_ultravox(question: str, audio_count: int):
                                                   tokenize=False,
                                                   add_generation_prompt=True)
 
-    llm = LLM(model=model_name,
-              enforce_eager=True,
-              enable_chunked_prefill=False,
-              max_model_len=8192,
-              limit_mm_per_prompt={"audio": audio_count})
+    llm = LLM(model=model_name, limit_mm_per_prompt={"audio": audio_count})
     stop_token_ids = None
    return llm, prompt, stop_token_ids
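
With placeholder positions tracked explicitly, the Ultravox example no longer needs to disable chunked prefill or pin enforce_eager and max_model_len; the default engine settings now handle audio prompts, which the chunked-prefill test cases added below exercise directly.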

tests/kernels/utils.py

Lines changed: 2 additions & 0 deletions
@@ -869,6 +869,7 @@ def make_test_metadata(
         return attn_backend.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
+            multi_modal_placeholder_index_maps=None,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
@@ -914,6 +915,7 @@ def make_test_metadata(
         return attn_backend.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=kv_mmap.slot_mapping,
+            multi_modal_placeholder_index_maps=None,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,

tests/models/decoder_only/audio_language/test_ultravox.py

Lines changed: 74 additions & 17 deletions
@@ -2,8 +2,10 @@
 
 import numpy as np
 import pytest
+import pytest_asyncio
 from transformers import AutoModel, AutoTokenizer, BatchEncoding
 
+from tests.utils import RemoteOpenAIServer
 from vllm.sequence import SampleLogprobs
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 
@@ -17,6 +19,13 @@
 VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
 HF_PLACEHOLDER = "<|audio|>"
 
+CHUNKED_PREFILL_KWARGS = {
+    "enable_chunked_prefill": True,
+    "max_num_seqs": 2,
+    # Use a very small limit to exercise chunked prefill.
+    "max_num_batched_tokens": 16
+}
+
 
 @pytest.fixture(scope="session")
 def audio_assets():
@@ -30,6 +39,26 @@ def audio(request):
     return AudioAsset(request.param)
 
 
+@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
+def server(request, audio_assets):
+    args = [
+        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
+        f"--limit-mm-per-prompt=audio={len(audio_assets)}"
+    ] + [
+        f"--{key.replace('_','-')}={value}"
+        for key, value in request.param.items()
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
 def _get_prompt(audio_count, question, placeholder):
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     placeholder = f"{placeholder}\n" * audio_count
@@ -68,8 +97,7 @@ def run_test(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
+    **kwargs,
 ):
     """Inference result should be the same between hf and vllm."""
     torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -79,11 +107,8 @@
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
 
-    with vllm_runner(model,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
+    with vllm_runner(model, dtype=dtype, enforce_eager=True,
+                     **kwargs) as vllm_model:
         vllm_outputs_per_audio = [
             vllm_model.generate_greedy_logprobs([vllm_prompt],
                                                 max_tokens,
@@ -135,18 +160,16 @@ def run_multi_audio_test(
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
+    **kwargs,
 ):
     with vllm_runner(model,
                      dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True,
                      limit_mm_per_prompt={
                          "audio":
                          max((len(audio) for _, audio in prompts_and_audios))
-                     }) as vllm_model:
+                     },
+                     **kwargs) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             [prompt for prompt, _ in prompts_and_audios],
             max_tokens,
@@ -162,8 +185,9 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
 def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
-                num_logprobs: int) -> None:
+                num_logprobs: int, vllm_kwargs: dict) -> None:
 
     vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
     hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
@@ -175,17 +199,18 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
+        **vllm_kwargs,
     )
 
 
 @pytest.mark.core_model
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
 def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
-                                     max_tokens: int,
-                                     num_logprobs: int) -> None:
+                                     max_tokens: int, num_logprobs: int,
+                                     vllm_kwargs: dict) -> None:
 
     vllm_prompt = _get_prompt(len(audio_assets),
                               "Describe each of the audios above.",
@@ -198,5 +223,37 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
         dtype=dtype,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
-        tensor_parallel_size=1,
+        **vllm_kwargs,
     )
+
+
+@pytest.mark.asyncio
+async def test_online_inference(client, audio_assets):
+    """Exercises online inference with/without chunked prefill enabled."""
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *[{
+                "type": "audio_url",
+                "audio_url": {
+                    "url": audio.url
+                }
+            } for audio in audio_assets],
+            {
+                "type":
+                "text",
+                "text":
+                f"What's happening in these {len(audio_assets)} audio clips?"
+            },
+        ],
+    }]
+
+    chat_completion = await client.chat.completions.create(model=MODEL_NAME,
+                                                           messages=messages,
+                                                           max_tokens=10)
+
+    assert len(chat_completion.choices) == 1
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
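
The new vllm_kwargs parametrization runs every test both with default settings and with chunked prefill forced into very small batches. A rough offline equivalent of the chunked-prefill variant, outside the test harness; the model name is an illustrative assumption, and the kwargs simply mirror CHUNKED_PREFILL_KWARGS and the server fixture above:

from vllm import LLM

# Sketch only: mirrors CHUNKED_PREFILL_KWARGS and the fixture args above.
# The model name is an assumption, not taken from this diff.
llm = LLM(
    model="fixie-ai/ultravox-v0_3",
    dtype="bfloat16",
    max_model_len=4096,
    enforce_eager=True,
    limit_mm_per_prompt={"audio": 2},
    enable_chunked_prefill=True,
    max_num_seqs=2,
    # A tiny token budget forces prompts to be prefilled across several
    # steps, which is exactly what the placeholder tracking must survive.
    max_num_batched_tokens=16,
)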

tests/multimodal/test_processor_kwargs.py

Lines changed: 7 additions & 7 deletions
@@ -5,8 +5,8 @@
 import pytest
 import torch
 
-from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs
-from vllm.inputs.registry import InputRegistry
+from vllm.inputs import (DecoderOnlyInputs, DummyData, InputContext,
+                         InputRegistry, token_inputs)
 from vllm.multimodal import MultiModalRegistry
 from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
 
@@ -56,7 +56,7 @@ def custom_dummy_data_factory(self,
                               num_crops=DEFAULT_NUM_CROPS):
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops))
-        return seq_data, None
+        return DummyData(seq_data, None)
 
     with patch(
             "vllm.inputs.registry.InputRegistry._default_dummy_data_factory",
@@ -177,9 +177,9 @@ def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
     # len is solely dependent on the value of the mm_processor_kwargs.
-    seq_data, _ = dummy_registry.dummy_data_for_profiling(
+    dummy_data = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(seq_data.prompt_token_ids) == expected_seq_count
+    assert len(dummy_data.seq_data.prompt_token_ids) == expected_seq_count
 
 
 @pytest.mark.parametrize(
@@ -206,9 +206,9 @@ def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock,
     # NOTE: seq_len is thrown away here since this will leverage the
     # default dummy data factory that we have patched in, whose seq
     # len is solely dependent on the value of the mm_processor_kwargs.
-    seq_data, _ = dummy_registry.dummy_data_for_profiling(
+    dummy_data = dummy_registry.dummy_data_for_profiling(
         ctx.model_config, seq_len=-1, mm_registry=mm_registry)
-    assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
+    assert len(dummy_data.seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
 
 
 ### Test overrides for the max token count per multimodal instance
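
Dummy data for profiling is now wrapped in a DummyData object (imported from vllm.inputs in the hunk above) rather than returned as a bare (seq_data, multi_modal_data) tuple, so it has room to carry placeholder information as well. A minimal sketch of the new shape, mirroring the patched factory above; the positional constructor with a None second field is exactly what the test uses:

from array import array

from vllm.inputs import DummyData
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData

# Build a tiny profiling sequence, mirroring the patched factory above.
seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * 4))

# DummyData replaces the old bare (seq_data, multi_modal_data) tuple;
# the second field stays None here, as in the test's dummy factory.
dummy_data = DummyData(seq_data, None)
assert len(dummy_data.seq_data.prompt_token_ids) == 4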

tests/multimodal/test_utils.py

Lines changed: 45 additions & 12 deletions
@@ -92,18 +92,50 @@ def test_repeat_and_pad_placeholder_tokens(model):
     tokenizer = AutoTokenizer.from_pretrained(model)
 
     test_cases = [
-        ("<image>", 2, "<image><image>", [32000, 32000]),
-        ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
-        ("<image><image>", [3, 2], "<image><image><image><image><image>",
-         [32000, 32000, 32000, 32000, 32000]),
-        ("Image:<image>Image:<image>!", [3, 2],
-         "Image:<image><image><image>Image:<image><image>!",
-         [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
-        ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
-    ]
-
-    for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
-        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
+        (
+            "<image>",
+            2,
+            "<image><image>",
+            [32000, 32000],
+            [{ "offset": 0, "length": 2 }],
+        ),
+        (
+            "<image><image>",
+            2,
+            "<image><image><image>",
+            [32000, 32000, 32000],
+            [{ "offset": 0, "length": 2 }]),
+        (
+            "<image><image>",
+            [3, 2],
+            "<image><image><image><image><image>",
+            [32000, 32000, 32000, 32000, 32000],
+            [{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
+        ),
+        (
+            "Image:<image>Image:<image>!",
+            [3, 2],
+            "Image:<image><image><image>Image:<image><image>!",
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
+        ),
+        (
+            "<image>",
+            [3, 2],
+            "<image><image><image>",
+            [32000, 32000, 32000],
+            [{ "offset": 0, "length": 3 }],
+        ),
+    ]  # yapf: disable
+
+    for (
+            prompt,
+            repeat_count,
+            expected_prompt,
+            expected_token_ids,
+            expected_ranges,
+    ) in test_cases:
+        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
             tokenizer=tokenizer,
             prompt=prompt,
             prompt_token_ids=tokenizer.encode(prompt,
@@ -113,3 +145,4 @@ def test_repeat_and_pad_placeholder_tokens(model):
         )
         assert new_prompt == expected_prompt
         assert new_token_ids == expected_token_ids
+        assert ranges == expected_ranges
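
repeat_and_pad_placeholder_tokens now returns a third value: the placeholder ranges it created, one offset/length entry per multi-modal item. A minimal usage sketch based on the first test case above; the import path, the keyword arguments beyond those visible in the hunk, and the LLaVA tokenizer with image token id 32000 are assumptions, not taken from this diff:

from transformers import AutoTokenizer

# Import path and the placeholder_token_id/repeat_count keywords are
# assumptions; only the three return values and the tokenizer/prompt/
# prompt_token_ids arguments appear in the diff above.
from vllm.multimodal.utils import repeat_and_pad_placeholder_tokens

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
    tokenizer=tokenizer,
    prompt="<image>",
    prompt_token_ids=tokenizer.encode("<image>", add_special_tokens=False),
    placeholder_token_id=32000,
    repeat_count=2,
)

assert new_prompt == "<image><image>"
assert new_token_ids == [32000, 32000]
# One range per multi-modal item, covering its placeholder tokens.
assert ranges == [{"offset": 0, "length": 2}]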

tests/worker/test_model_input.py

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,7 @@ def test_model_runner_input():
         num_prefill_tokens=2,
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
+        multi_modal_placeholder_index_maps=None,
     )
     model_input = ModelInputForGPUWithSamplingMetadata(
         input_tokens=torch.ones(10),
@@ -124,6 +125,7 @@ def test_embedding_model_runner_input():
         num_prefill_tokens=2,
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
+        multi_modal_placeholder_index_maps=None,
     )
     model_input = ModelInputForGPUWithPoolingMetadata(
         input_tokens=torch.ones(10),
@@ -174,6 +176,7 @@ def test_multi_step_model_runner_input():
         num_prefill_tokens=2,
         num_decode_tokens=3,
         slot_mapping=torch.zeros(1),
+        multi_modal_placeholder_index_maps=None,
     )
     frozen_model_input = ModelInputForGPUWithSamplingMetadata(
         input_tokens=torch.ones(10),

vllm/attention/backends/abstract.py

Lines changed: 11 additions & 0 deletions
@@ -7,6 +7,8 @@
 
 import torch
 
+from vllm.multimodal import MultiModalPlaceholderMap
+
 if TYPE_CHECKING:
     from vllm.worker.model_runner_base import (ModelRunnerBase,
                                                ModelRunnerInputBase,
@@ -108,6 +110,15 @@ class AttentionMetadata:
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor
 
+    # The index maps that relate multi-modal embeddings to the corresponding
+    # placeholders.
+    #
+    # N.B. These aren't really related to attention and don't belong on this
+    # type -- this is just a temporary solution to make them available to
+    # `model_executable`.
+    multi_modal_placeholder_index_maps: Optional[Dict[
+        str, MultiModalPlaceholderMap.IndexMap]]
+
     @property
     @abstractmethod
     def prefill_metadata(self) -> Optional["AttentionMetadata"]:
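
The new field gives the model runner a place to hand each batch's mapping from multi-modal embedding rows to the placeholder positions they should fill, keyed by modality. The IndexMap internals are not part of this hunk; the following toy scatter with plain tensors and hypothetical src/dest names only illustrates the kind of bookkeeping the index maps carry:

import torch

# Toy illustration: scatter audio embeddings into the placeholder positions
# of a flattened batch. vLLM's MultiModalPlaceholderMap.IndexMap serves the
# same purpose; its exact fields are not shown in this diff.
hidden_size = 4
inputs_embeds = torch.zeros(10, hidden_size)  # flattened batch of 10 tokens
audio_embeds = torch.randn(3, hidden_size)    # 3 audio embedding vectors

# dest: which token positions are audio placeholders; src: which embedding
# rows fill them.
dest_indices = torch.tensor([2, 3, 4])
src_indices = torch.tensor([0, 1, 2])

inputs_embeds[dest_indices] = audio_embeds[src_indices]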

vllm/attention/backends/blocksparse_attn.py

Lines changed: 3 additions & 0 deletions
@@ -215,6 +215,8 @@ def prefill_metadata(
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            multi_modal_placeholder_index_maps=self.
+            multi_modal_placeholder_index_maps,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
@@ -243,6 +245,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            multi_modal_placeholder_index_maps=None,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,
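
Note the asymmetry above: prefill_metadata forwards the placeholder index maps while decode_metadata passes None, presumably because multi-modal embeddings only need to be merged into the input while the prompt is being prefilled; the same pattern repeats in the other attention backends in this commit.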
