
Commit 566cd27

[torch.compile] rework test plans (#9866)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent 37a4947 commit 566cd27

File tree

4 files changed: +226 -31 lines changed

Lines changed: 95 additions & 18 deletions
@@ -1,3 +1,4 @@
+import dataclasses
 from typing import Dict, List, Optional
 
 import pytest
@@ -8,33 +9,109 @@
 from ..utils import compare_all_settings
 
 
+@dataclasses.dataclass
+class TestSetting:
+    model: str
+    model_args: List[str]
+    pp_size: int
+    tp_size: int
+    attn_backend: str
+    method: str
+    fullgraph: bool
+
+
+# representative settings for testing
+test_settings = [
+    # basic llama model
+    TestSetting(
+        model="meta-llama/Llama-3.2-1B",
+        model_args=[],
+        pp_size=2,
+        tp_size=2,
+        attn_backend="FLASHINFER",
+        method="generate",
+        fullgraph=True,
+    ),
+    # llama model with quantization
+    TestSetting(
+        model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+        model_args=["--quantization", "gptq"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="FLASH_ATTN",
+        method="generate",
+        fullgraph=True,
+    ),
+    # MoE model
+    TestSetting(
+        model="ibm/PowerMoE-3b",
+        model_args=[],
+        pp_size=1,
+        tp_size=2,
+        attn_backend="FLASH_ATTN",
+        method="generate",
+        fullgraph=True,
+    ),
+    # embedding model
+    TestSetting(
+        model="BAAI/bge-multilingual-gemma2",
+        model_args=["--task", "embedding"],
+        pp_size=1,
+        tp_size=1,
+        attn_backend="FLASHINFER",
+        method="encode",
+        fullgraph=True,
+    ),
+    # vision language model
+    TestSetting(
+        model="microsoft/Phi-3.5-vision-instruct",
+        model_args=["--trust-remote-code", "--max-model-len", "2048"],
+        pp_size=2,
+        tp_size=1,
+        attn_backend="FLASH_ATTN",
+        method="generate_with_image",
+        fullgraph=False,
+    ),
+]
+
+
 # we cannot afford testing the full Catesian product
 # of all models and all levels
-@pytest.mark.parametrize(
-    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
-    [
-        ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
-        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
-         ["--quantization", "compressed-tensors"
-          ], 1, 1, "FLASH_ATTN", "generate", True),
-        ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
-        # TODO: add multi-modality test for llava
-        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
-    ])
-def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
-                             method, fullgraph):
+@pytest.mark.parametrize("test_setting", test_settings)
+def test_compile_correctness(test_setting: TestSetting):
     # this test is run under multiple suits, with different GPUs.
     # make sure we only run the test with correct CUDA devices.
     # don't use "<", as it will duplicate the tests.
+    model = test_setting.model
+    model_args = test_setting.model_args
+    pp_size = test_setting.pp_size
+    tp_size = test_setting.tp_size
+    attn_backend = test_setting.attn_backend
+    method = test_setting.method
+    fullgraph = test_setting.fullgraph
     if cuda_device_count_stateless() != pp_size * tp_size:
         pytest.skip("Not correct CUDA devices for the test.")
     import os
     os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
-    all_args = [["--enforce-eager"] + model_args + ["-pp", str(pp_size)] +
-                ["-tp", str(tp_size)]] * 3
-    # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
-    # inductor will change the output, so we cannot compare them.
+    final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
+        ["-tp", str(tp_size)]
+
     all_envs: List[Optional[Dict[str, str]]] = []
+
+    for level in [
+            CompilationLevel.NO_COMPILATION,
+            CompilationLevel.PIECEWISE,
+    ]:
+        all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
+
+    # inductor will change the output, so we only compare if the output
+    # is close, not exactly the same.
+    compare_all_settings(
+        model, [final_args] * 2,
+        all_envs,
+        method=method if method != "generate" else "generate_close")
+    all_envs.clear()
+
     for level in [
             CompilationLevel.NO_COMPILATION,
             CompilationLevel.DYNAMO_AS_IS,
@@ -46,4 +123,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
         all_envs[-1][
             "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"  # type: ignore
 
-    compare_all_settings(model, all_args, all_envs, method=method)
+    compare_all_settings(model, [final_args] * 3, all_envs, method=method)
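
The skip condition above only runs a setting when the visible GPU count exactly matches pp_size * tp_size, so each CI suite picks up exactly the settings it can serve. Below is a minimal standalone sketch of that selection rule; the helper `runnable_settings` is illustrative and not part of the commit:

```python
import dataclasses
from typing import List


@dataclasses.dataclass
class TestSetting:
    model: str
    model_args: List[str]
    pp_size: int
    tp_size: int
    attn_backend: str
    method: str
    fullgraph: bool


def runnable_settings(settings: List[TestSetting],
                      num_gpus: int) -> List[TestSetting]:
    # mirror the test's skip rule: require an exact match (not "<="),
    # so the same setting is not exercised twice by suites with more GPUs
    return [s for s in settings if s.pp_size * s.tp_size == num_gpus]


# e.g. on a 2-GPU runner only the pp*tp == 2 settings remain
two_gpu = runnable_settings([
    TestSetting("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
    TestSetting("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate",
                True),
], num_gpus=2)
```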

tests/utils.py

Lines changed: 119 additions & 5 deletions
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import functools
 import os
 import signal
@@ -8,13 +9,14 @@
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import openai
 import pytest
 import requests
+import torch
 from openai.types.completion import Completion
-from typing_extensions import ParamSpec, assert_never
+from typing_extensions import ParamSpec
 
 import vllm.envs as envs
 from tests.models.utils import TextTextLogprobs
@@ -272,6 +274,31 @@ def _test_completion(
     return results
 
 
+def _test_completion_close(
+    client: openai.OpenAI,
+    model: str,
+    prompt: str,
+):
+    results = []
+
+    # test with text prompt
+    completion = client.completions.create(model=model,
+                                           prompt=prompt,
+                                           max_tokens=1,
+                                           logprobs=5,
+                                           temperature=0.0)
+
+    logporbs = completion.choices[0].logprobs.top_logprobs[0]
+    logporbs = {k: round(v, 2) for k, v in logporbs.items()}
+
+    results.append({
+        "test": "completion_close",
+        "logprobs": logporbs,
+    })
+
+    return results
+
+
 def _test_embeddings(
     client: openai.OpenAI,
     model: str,
@@ -295,13 +322,81 @@ def _test_embeddings(
     return results
 
 
+def _test_image_text(
+    client: openai.OpenAI,
+    model_name: str,
+    image_url: str,
+):
+    results = []
+
+    # test pure text input
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "How do you feel today?"
+            },
+        ],
+    }]
+
+    chat_completion = client.chat.completions.create(model=model_name,
+                                                     messages=messages,
+                                                     temperature=0.0,
+                                                     max_tokens=1,
+                                                     logprobs=True,
+                                                     top_logprobs=5)
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
+
+    for x in top_logprobs:
+        x.logprob = round(x.logprob, 2)
+
+    results.append({
+        "test": "pure_text",
+        "logprobs": top_logprobs,
+    })
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            },
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+
+    chat_completion = client.chat.completions.create(model=model_name,
+                                                     messages=messages,
+                                                     temperature=0.0,
+                                                     max_tokens=1,
+                                                     logprobs=True,
+                                                     top_logprobs=5)
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
+
+    results.append({
+        "test": "text_image",
+        "logprobs": top_logprobs,
+    })
+
+    return results
+
+
 def compare_two_settings(model: str,
                          arg1: List[str],
                          arg2: List[str],
                          env1: Optional[Dict[str, str]] = None,
                          env2: Optional[Dict[str, str]] = None,
                          *,
-                         method: Literal["generate", "encode"] = "generate",
+                         method: str = "generate",
                          max_wait_seconds: Optional[float] = None) -> None:
     """
     Launch API server with two different sets of arguments/environments
@@ -328,7 +423,7 @@ def compare_all_settings(model: str,
                          all_args: List[List[str]],
                          all_envs: List[Optional[Dict[str, str]]],
                          *,
-                         method: Literal["generate", "encode"] = "generate",
+                         method: str = "generate",
                          max_wait_seconds: Optional[float] = None) -> None:
     """
     Launch API server with several different sets of arguments/environments
@@ -397,10 +492,17 @@ def compare_all_settings(model: str,
 
             if method == "generate":
                 results += _test_completion(client, model, prompt, token_ids)
+            elif method == "generate_close":
+                results += _test_completion_close(client, model, prompt)
+            elif method == "generate_with_image":
+                results += _test_image_text(
+                    client, model,
+                    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
+                )
             elif method == "encode":
                 results += _test_embeddings(client, model, prompt)
             else:
-                assert_never(method)
+                raise ValueError(f"Unknown method: {method}")
 
         if i > 0:
             # if any setting fails, raise an error early
@@ -410,6 +512,18 @@ def compare_all_settings(model: str,
             compare_envs = all_envs[i]
             for ref_result, compare_result in zip(ref_results,
                                                   compare_results):
+                ref_result = copy.deepcopy(ref_result)
+                compare_result = copy.deepcopy(compare_result)
+                if "embedding" in ref_result and method == "encode":
+                    ref_embedding = torch.tensor(ref_result["embedding"])
+                    compare_embedding = torch.tensor(
+                        compare_result["embedding"])
+                    mse = ((ref_embedding - compare_embedding)**2).mean()
+                    assert mse < 1e-6, (
+                        f"Embedding for {model=} are not the same.\n"
+                        f"mse={mse}\n")
+                    del ref_result["embedding"]
+                    del compare_result["embedding"]
                 assert ref_result == compare_result, (
                     f"Results for {model=} are not the same.\n"
                     f"{ref_args=} {ref_envs=}\n"

vllm/model_executor/models/llava.py

Lines changed: 5 additions & 5 deletions
@@ -493,13 +493,9 @@ def forward(
             :class:`LlavaImageInputs`
         """
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
             image_input = self._parse_and_validate_image_input(**kwargs)
-
             if image_input is not None:
                 vision_embeddings = self._process_image_input(image_input)
                 inputs_embeds = self.language_model.model.get_input_embeddings(
@@ -511,7 +507,11 @@ def forward(
             else:
                 inputs_embeds = self.language_model.model.get_input_embeddings(
                     input_ids)
-                input_ids = None
+
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            # for `torch.compile` integration
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
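
Both model changes follow the same pattern: token ids are always converted to embeddings first and `input_ids` is then set to None, so `torch.compile` traces one graph that starts from `inputs_embeds` whether or not multimodal embeddings were merged in. A toy illustration of the pattern, not vLLM code (the eager backend is used so the sketch runs without a C++ toolchain):

```python
import torch
import torch.nn as nn


class ToyLM(nn.Module):

    def __init__(self, vocab: int = 128, hidden: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(hidden, vocab)

    def forward(self, input_ids=None, inputs_embeds=None):
        # always route through inputs_embeds so the traced graph has one
        # consistent entry point, whether the caller passed ids or embeddings
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        input_ids = None
        return self.proj(inputs_embeds)


base = ToyLM()
model = torch.compile(base, backend="eager")
ids = torch.randint(0, 128, (1, 8))
out_from_ids = model(input_ids=ids)
out_from_embeds = model(inputs_embeds=base.embed(ids))
```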

vllm/model_executor/models/phi3v.py

Lines changed: 7 additions & 3 deletions
@@ -679,7 +679,6 @@ def forward(self,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs: object):
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
@@ -690,9 +689,14 @@ def forward(self,
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, vision_embeddings,
                     self.image_token_id)
-                input_ids = None
             else:
-                inputs_embeds = None
+                inputs_embeds = self.language_model.model.embed_tokens(
+                    input_ids)
+
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            # for `torch.compile` integration
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
