Commit cff6a1f

[CI/Build] Reuse code for checking output consistency (#5988)
1 parent bcc6a09 commit cff6a1f

11 files changed, +125 -75 lines changed
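Every test file below applies the same mechanical change: the per-prompt comparison loop that had been copy-pasted across the test suite is replaced by a single call to a new shared helper, check_outputs_equal, added in tests/models/utils.py (the last file in this diff). A minimal, self-contained sketch of the before/after pattern, using made-up data (the import path is an assumption; run from the repository root):

from tests.models.utils import check_outputs_equal  # import path assumed

# Hypothetical outputs: each entry is a (token_ids, text) tuple, which is the
# shape returned by generate_greedy in the test runners.
hf_outputs = [([1, 2, 3], "a b c"), ([4, 5, 6], "d e f")]
vllm_outputs = [([1, 2, 3], "a b c"), ([4, 5, 6], "d e f")]

# Before: every test carried its own copy of this loop.
for i in range(len(hf_outputs)):
    hf_output_ids, hf_output_str = hf_outputs[i]
    vllm_output_ids, vllm_output_str = vllm_outputs[i]
    assert hf_output_str == vllm_output_str
    assert hf_output_ids == vllm_output_ids

# After: one shared helper performs the same element-wise checks and
# additionally asserts that both lists have the same length.
check_outputs_equal(
    outputs_0_lst=hf_outputs,
    outputs_1_lst=vllm_outputs,
    name_0="hf",
    name_1="vllm",
)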

tests/basic_correctness/test_basic_correctness.py

Lines changed: 8 additions & 7 deletions
@@ -8,6 +8,8 @@

 from vllm import LLM

+from ..models.utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -46,10 +48,9 @@ def test_models(
                      gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )

tests/basic_correctness/test_chunked_prefill.py

Lines changed: 8 additions & 7 deletions
@@ -8,6 +8,8 @@
 """
 import pytest

+from ..models.utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -54,10 +56,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )

tests/basic_correctness/test_preemption.py

Lines changed: 9 additions & 7 deletions
@@ -12,6 +12,8 @@
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)

+from ..models.utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
 ]
@@ -94,13 +96,13 @@ def test_preemption(
         total_preemption = (
             vllm_model.model.llm_engine.scheduler.num_cumulative_preemption)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
     assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
             "is not enough KV cache space." in caplog_vllm.text)
     # Ensure the count bucket of request-level histogram metrics matches

tests/distributed/test_basic_distributed_correctness.py

Lines changed: 8 additions & 7 deletions
@@ -17,6 +17,8 @@
 import pytest
 import torch

+from ..models.utils import check_outputs_equal
+
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
@@ -48,10 +50,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )

tests/distributed/test_chunked_prefill_distributed.py

Lines changed: 8 additions & 7 deletions
@@ -16,6 +16,8 @@
 import pytest
 import torch

+from ..models.utils import check_outputs_equal
+
 MODELS = [
     os.environ["TEST_DIST_MODEL"],
 ]
@@ -59,10 +61,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )

tests/models/test_big_models.py

Lines changed: 8 additions & 7 deletions
@@ -7,6 +7,8 @@
 import pytest
 import torch

+from .utils import check_outputs_equal
+
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
     # "mistralai/Mistral-7B-v0.1",  # Tested by test_mistral.py
@@ -40,13 +42,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )


 @pytest.mark.parametrize("model", MODELS)

tests/models/test_llava.py

Lines changed: 10 additions & 8 deletions
@@ -6,6 +6,7 @@
 from vllm.config import VisionLanguageConfig

 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal

 pytestmark = pytest.mark.vlm

@@ -109,14 +110,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)

-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )


 @pytest.mark.parametrize("model_and_config", model_and_vl_config)
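The three vision-language tests (this file, plus test_llava_next.py and test_phi3v.py below) follow the same pattern, except that each vLLM output is first passed through the file-local vllm_to_hf_output adapter so it is directly comparable to the HuggingFace output. That adapter is untouched by this commit and its body is not shown here; a hypothetical stand-in illustrating the only contract the shared helper relies on, namely that it returns a (token_ids, text) tuple:

from typing import List, Tuple

# Hypothetical sketch; the real vllm_to_hf_output is defined in each VLM test
# file and typically normalizes model-specific prompt/image tokens (assumed).
def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config, model_id: str) -> Tuple[List[int], str]:
    output_ids, output_str = vllm_output
    return output_ids, output_str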

tests/models/test_llava_next.py

Lines changed: 10 additions & 8 deletions
@@ -6,6 +6,7 @@
 from vllm.config import VisionLanguageConfig

 from ..conftest import IMAGE_ASSETS
+from .utils import check_outputs_equal

 pytestmark = pytest.mark.vlm

@@ -115,11 +116,12 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,
                                                   max_tokens,
                                                   images=vllm_images)

-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )

tests/models/test_models.py

Lines changed: 8 additions & 7 deletions
@@ -7,6 +7,8 @@
 """
 import pytest

+from .utils import check_outputs_equal
+
 MODELS = [
     "facebook/opt-125m",
     "gpt2",
@@ -41,13 +43,12 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

-    for i in range(len(example_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )


 @pytest.mark.parametrize("model", MODELS)

tests/models/test_phi3v.py

Lines changed: 10 additions & 8 deletions
@@ -7,6 +7,7 @@
 from vllm.utils import is_cpu

 from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
+from .utils import check_outputs_equal

 pytestmark = pytest.mark.vlm

@@ -124,14 +125,15 @@ def run_test(
                                                   max_tokens,
                                                   images=vllm_images)

-    for i in range(len(HF_IMAGE_PROMPTS)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
-            vllm_outputs[i], vlm_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_outputs_equal(
+        hf_outputs,
+        [
+            vllm_to_hf_output(vllm_output, vlm_config, model_id)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+    )


 # Since we use _attn_implementation="eager" for hf_runner, here is

tests/models/utils.py

Lines changed: 38 additions & 2 deletions
@@ -1,7 +1,43 @@
-def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
-    """Compare the logprobs of two sequences generated by different models,
+from typing import Dict, List, Tuple
+
+TokensText = Tuple[List[int], str]
+
+
+def check_outputs_equal(outputs_0_lst: List[TokensText],
+                        outputs_1_lst: List[TokensText], name_0: str,
+                        name_1: str):
+    """
+    Compare the two sequences generated by different models,
+    which should be equal.
+    """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
+    for prompt_idx, (outputs_0,
+                     outputs_1) in enumerate(zip(outputs_0_lst,
+                                                 outputs_1_lst)):
+        output_ids_0, output_str_0 = outputs_0
+        output_ids_1, output_str_1 = outputs_1
+
+        assert output_str_0 == output_str_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+        assert output_ids_0 == output_ids_1, (f"Test{prompt_idx}:"
+                                              f"\n{name_0}:\t{output_str_0!r}"
+                                              f"\n{name_1}:\t{output_str_1!r}")
+
+
+TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+
+
+def check_logprobs_close(outputs_0_lst: List[TokensTextLogprobs],
+                         outputs_1_lst: List[TokensTextLogprobs], name_0: str,
+                         name_1: str):
+    """
+    Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
     """
+    assert len(outputs_0_lst) == len(outputs_1_lst)
+
     # Loop through responses to each prompt.
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,
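For reference, a hedged, self-contained sketch of how the new helper behaves on a mismatch (the data is invented; the helper is exactly the one added above, and the import path is an assumption):

import pytest

from tests.models.utils import check_outputs_equal  # import path assumed

outputs_a = [([101, 102], "hello world"), ([103], "foo")]
outputs_b = [([101, 102], "hello world"), ([104], "bar")]

# Identical lists pass silently.
check_outputs_equal(outputs_a, outputs_a, name_0="hf", name_1="vllm")

# A mismatch raises an AssertionError whose message names the prompt index
# and both runners, roughly:  Test1: / hf: 'foo' / vllm: 'bar'
with pytest.raises(AssertionError):
    check_outputs_equal(outputs_a, outputs_b, name_0="hf", name_1="vllm")

check_logprobs_close keeps its previous behavior; it only gains type annotations (TokensTextLogprobs adds a per-token logprobs dict to each tuple) and the same length assertion.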
