Commit bd98842

[CI] Add PPL test for generation models (#24485)
Signed-off-by: wang.yuqi <noooop@126.com>
1 parent d606988

File tree

9 files changed: +211 -7 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 10 additions & 0 deletions
@@ -604,6 +604,16 @@ steps:
    - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
    - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'

+- label: Language Models Test (PPL)
+  timeout_in_minutes: 110
+  mirror_hardwares: [amdexperimental]
+  optional: true
+  source_file_dependencies:
+  - vllm/
+  - tests/models/language/generation_ppl_test
+  commands:
+    - pytest -v -s models/language/generation_ppl_test
+
- label: Language Models Test (Extended Pooling) # 36min
  timeout_in_minutes: 50
  mirror_hardwares: [amdexperimental]
tests/models/language/generation_ppl_test/__init__.py

Whitespace-only changes.
tests/models/language/generation_ppl_test/ppl_utils.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/docs/transformers/perplexity
from typing import Optional, cast

import pytest
import torch
from datasets import load_dataset

from tests.models.utils import (GenerateModelInfo,
                                TokensTextLogprobsPromptLogprobs)
from vllm.logprobs import Logprob

# See #24485
PPL_TOL = 0.01
MAX_LENGTH = 1024


@torch.inference_mode
def wikitext_ppl_test(hf_runner,
                      vllm_runner,
                      model_info: GenerateModelInfo,
                      max_length=MAX_LENGTH,
                      vllm_extra_kwargs=None,
                      atol=PPL_TOL):

    # A model family has many models with the same architecture,
    # and we don't need to test each one.
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Allow vllm to test using the given dtype, such as float32
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = model_info.dtype

    # Allow vllm to test using hf_overrides
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

    with vllm_runner(model_info.name,
                     gpu_memory_utilization=0.7,
                     max_model_len=max_length,
                     max_num_seqs=1,
                     enforce_eager=True,
                     **vllm_extra_kwargs) as vllm_model:
        # Use max_num_seqs=1 to avoid OOM,
        # and batch different requests together.

        model_config = vllm_model.llm.llm_engine.model_config

        # Confirm whether vllm is using the correct architecture
        if model_info.architecture:
            assert (model_info.architecture in model_config.architectures)

        max_length = min(model_config.max_model_len - 1, max_length)
        stride = max_length

        tokenizer = vllm_model.llm.get_tokenizer()
        tokens = tokenizer.encode("\n\n".join(dataset["text"]))
        n_tokens = len(tokens)

        chunks = []
        for begin_loc in range(0, n_tokens, stride):
            end_loc = min(begin_loc + max_length, n_tokens)
            chunks.append(tokens[begin_loc:end_loc])

        outputs = vllm_model.generate_greedy_logprobs(prompts=chunks,
                                                      max_tokens=1,
                                                      num_logprobs=None,
                                                      num_prompt_logprobs=0,
                                                      use_tqdm=False)
        nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
        n_tokens = 0
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
            token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])

            assert token_datas[0] is None
            token_log_probs = []
            for token_data in token_datas[1:]:
                assert token_data is not None
                assert len(token_data) == 1
                token_log_prob = list(token_data.values())[0].logprob
                token_log_probs.append(token_log_prob)

            neg_log_likelihood = -torch.tensor(
                token_log_probs, dtype=torch.float32, device="cpu").sum()
            nll_sum += neg_log_likelihood
            n_tokens += len(token_log_probs)
        vllm_ppl = float(torch.exp(nll_sum / n_tokens))
        vllm_dtype = model_config.dtype

    # Accelerate ppl test by setting Transformers ppl score to a constant
    if model_info.hf_ppl is None:
        with hf_runner(
                model_info.name,
                dtype=model_info.hf_dtype,
        ) as hf_model:
            nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
            n_tokens = 0
            for chunk in chunks:
                inputs = hf_model.wrap_device(
                    {"input_ids": torch.tensor([chunk])})
                input_ids = inputs["input_ids"]
                outputs = hf_model.model(input_ids, labels=input_ids)
                neg_log_likelihood = outputs.loss

                neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu()

                num_loss_tokens = len(chunk) - 1
                nll_sum += neg_log_likelihood * num_loss_tokens
                n_tokens += num_loss_tokens

            hf_ppl = float(torch.exp(nll_sum / n_tokens))
            hf_dtype = next(hf_model.model.parameters()).dtype
    else:
        hf_ppl = model_info.hf_ppl
        hf_dtype = "Constant"

    differ = (vllm_ppl - hf_ppl) / hf_ppl
    print("Model:", model_info.name)
    print("VLLM:", vllm_dtype, vllm_ppl)
    print("Transformers:", hf_dtype, hf_ppl)
    print("Difference (%):", differ * 100)

    # PPL the smaller, the better
    # We are not concerned that the vllm PPL is less than Transformers,
    # so we only perform one-sided testing.
    assert differ < atol
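For a concrete sense of the acceptance rule: the check is one-sided and relative, so vLLM may beat the Transformers baseline by any margin, but may be worse by less than PPL_TOL (1%) of the baseline. A minimal sketch with hypothetical numbers, not measured values:

# Hypothetical numbers illustrating the one-sided relative tolerance used above.
hf_ppl = 20.0                            # assumed Transformers baseline perplexity
vllm_ppl = 20.15                         # assumed vLLM perplexity
differ = (vllm_ppl - hf_ppl) / hf_ppl    # 0.0075
assert differ < 0.01                     # passes; a vLLM PPL of 20.2 or higher would fail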
tests/models/language/generation_ppl_test/ (new PPL test, Gemma models)

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from tests.models.utils import GenerateModelInfo

from .ppl_utils import wikitext_ppl_test

MODELS = [
    GenerateModelInfo("google/gemma-2b"),
    GenerateModelInfo("google/gemma-2-2b"),
    GenerateModelInfo("google/gemma-3-4b-it"),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
    wikitext_ppl_test(hf_runner, vllm_runner, model_info)
tests/models/language/generation_ppl_test/ (new PPL test, GPT-2)

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from tests.models.utils import GenerateModelInfo

from .ppl_utils import wikitext_ppl_test

MODELS = [GenerateModelInfo("openai-community/gpt2-large")]


@pytest.mark.parametrize("model_info", MODELS)
def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
    wikitext_ppl_test(hf_runner, vllm_runner, model_info)
tests/models/language/generation_ppl_test/ (new PPL test, Qwen3 models)

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from tests.models.utils import GenerateModelInfo

from .ppl_utils import wikitext_ppl_test

MODELS = [
    GenerateModelInfo("Qwen/Qwen3-0.6B"),
    GenerateModelInfo("Qwen/Qwen3-0.6B-FP8"),
    # transformers:
    # Loading a GPTQ quantized model requires optimum, gptqmodel
    # GenerateModelInfo("Qwen/Qwen3-0.6B-GPTQ-Int8"),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
    wikitext_ppl_test(hf_runner, vllm_runner, model_info)

tests/models/language/pooling/embed_utils.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def correctness_test_embed_models(hf_runner,

    with hf_runner(
            model_info.name,
-            dtype="float32",
+            dtype=model_info.hf_dtype,
            is_sentence_transformer=True,
    ) as hf_model:

tests/models/language/pooling/mteb_utils.py

Lines changed: 7 additions & 4 deletions
@@ -213,7 +213,7 @@ def mteb_test_embed_models(hf_runner,
    if model_info.mteb_score is None:
        with hf_runner(model_info.name,
                       is_sentence_transformer=True,
-                       dtype="float32") as hf_model:
+                       dtype=model_info.hf_dtype) as hf_model:

            # e.g. setting default parameters for the encode method of hf_runner
            if hf_model_callback is not None:
@@ -278,9 +278,12 @@ def run_mteb_rerank(cross_encoder, tasks, languages):
    return main_score


-def mteb_test_rerank_models_hf(hf_runner, model_name, hf_model_callback=None):
+def mteb_test_rerank_models_hf(hf_runner,
+                               model_name,
+                               hf_dtype="float32",
+                               hf_model_callback=None):
    with hf_runner(model_name, is_cross_encoder=True,
-                   dtype="float32") as hf_model:
+                   dtype=hf_dtype) as hf_model:

        original_predict = hf_model.predict

@@ -357,7 +360,7 @@ def mteb_test_rerank_models(hf_runner,
    # SentenceTransformers mteb score to a constant
    if model_info.mteb_score is None:
        st_main_score, st_dtype = mteb_test_rerank_models_hf(
-            hf_runner, model_info.name, hf_model_callback)
+            hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback)
    else:
        st_main_score = model_info.mteb_score
        st_dtype = "Constant"
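The new hf_dtype parameter makes the dtype of the cross-encoder reference configurable instead of hard-coding float32. A hedged usage sketch; the model name below is a placeholder, not part of this change:

# Hypothetical call showing the new hf_dtype argument; the model name is illustrative only.
st_main_score, st_dtype = mteb_test_rerank_models_hf(
    hf_runner, "BAAI/bge-reranker-base", hf_dtype="float16")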

tests/models/utils.py

Lines changed: 9 additions & 2 deletions
@@ -347,14 +347,15 @@ class ModelInfo:
    name: str
    architecture: str = ""
    dtype: str = "auto"
+    hf_dtype: str = "float32"
    hf_overrides: Optional[dict[str, Any]] = None
    default_pooling_type: str = ""
-    mteb_score: Optional[float] = None
    enable_test: bool = True


@dataclass
class EmbedModelInfo(ModelInfo):
+    mteb_score: Optional[float] = None
    is_matryoshka: bool = False
    matryoshka_dimensions: Optional[list[int]] = None

@@ -371,7 +372,7 @@ class LASTPoolingEmbedModelInfo(EmbedModelInfo):

@dataclass
class RerankModelInfo(ModelInfo):
-    pass
+    mteb_score: Optional[float] = None


@dataclass
@@ -384,6 +385,12 @@ class LASTPoolingRerankModelInfo(RerankModelInfo):
    default_pooling_type: str = "LAST"


+@dataclass
+class GenerateModelInfo(ModelInfo):
+    hf_dtype: str = "auto"
+    hf_ppl: Optional[float] = None
+
+
def dummy_hf_overrides(
    hf_config: PretrainedConfig,
    *,
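The new GenerateModelInfo fields feed the PPL helper: hf_dtype selects the dtype for the Transformers baseline run, and a non-None hf_ppl pins the baseline so that run is skipped (the "Constant" path in ppl_utils). A hedged sketch; the numeric value is a placeholder, not a measured score:

# Hypothetical registration that pins the Transformers perplexity; with hf_ppl set,
# wikitext_ppl_test compares vLLM against this constant instead of running HF.
from tests.models.utils import GenerateModelInfo

MODELS = [
    GenerateModelInfo("openai-community/gpt2-large",
                      hf_ppl=12.3),  # placeholder baseline, not a measured value
]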
