# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/docs/transformers/perplexity
from typing import Optional, cast

import pytest
import torch
from datasets import load_dataset

from tests.models.utils import (GenerateModelInfo,
                                TokensTextLogprobsPromptLogprobs)
from vllm.logprobs import Logprob

# See #24485
PPL_TOL = 0.01
MAX_LENGTH = 1024
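# PPL_TOL is a relative tolerance: the vLLM perplexity may exceed the
# Transformers reference by at most 1% (see the one-sided assert at the
# end of wikitext_ppl_test).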


@torch.inference_mode
def wikitext_ppl_test(hf_runner,
                      vllm_runner,
                      model_info: GenerateModelInfo,
                      max_length=MAX_LENGTH,
                      vllm_extra_kwargs=None,
                      atol=PPL_TOL):

    # A model family has many models with the same architecture,
    # so we don't need to test every one of them.
    if not model_info.enable_test:
        pytest.skip("Skipping test.")

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

    # Allow vLLM to run the test with the given dtype, such as float32
    vllm_extra_kwargs = vllm_extra_kwargs or {}
    vllm_extra_kwargs["dtype"] = model_info.dtype

    # Allow vLLM to run the test with hf_overrides
    if model_info.hf_overrides is not None:
        vllm_extra_kwargs["hf_overrides"] = model_info.hf_overrides

    # Use max_num_seqs=1 to avoid the OOM that batching different
    # requests together could cause.
    with vllm_runner(model_info.name,
                     gpu_memory_utilization=0.7,
                     max_model_len=max_length,
                     max_num_seqs=1,
                     enforce_eager=True,
                     **vllm_extra_kwargs) as vllm_model:
        model_config = vllm_model.llm.llm_engine.model_config

        # Confirm that vLLM is using the expected architecture
        if model_info.architecture:
            assert model_info.architecture in model_config.architectures

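        # Keep prompts at most max_model_len - 1 tokens long so there is
        # room for the single generated token requested below.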
        max_length = min(model_config.max_model_len - 1, max_length)
        stride = max_length

        tokenizer = vllm_model.llm.get_tokenizer()
        tokens = tokenizer.encode("\n\n".join(dataset["text"]))
        n_tokens = len(tokens)

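        # Split the tokenized test set into non-overlapping windows of
        # max_length tokens (stride == max_length, so windows don't
        # overlap); each window is scored independently.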
        chunks = []
        for begin_loc in range(0, n_tokens, stride):
            end_loc = min(begin_loc + max_length, n_tokens)
            chunks.append(tokens[begin_loc:end_loc])

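        # Request prompt logprobs for every chunk. With num_prompt_logprobs=0
        # each prompt position carries only the logprob of the actual token
        # (no extra top-k candidates); the single generated token is unused.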
        outputs = vllm_model.generate_greedy_logprobs(prompts=chunks,
                                                      max_tokens=1,
                                                      num_logprobs=None,
                                                      num_prompt_logprobs=0,
                                                      use_tqdm=False)
        nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
        n_tokens = 0
        for output in outputs:
            output = cast(TokensTextLogprobsPromptLogprobs, output)
            token_datas = cast(list[Optional[dict[int, Logprob]]], output[3])

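            # The first prompt token has no preceding context, so no logprob
            # is reported for it.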
            assert token_datas[0] is None
            token_log_probs = []
            for token_data in token_datas[1:]:
                assert token_data is not None
                assert len(token_data) == 1
                token_log_prob = list(token_data.values())[0].logprob
                token_log_probs.append(token_log_prob)

            neg_log_likelihood = -torch.tensor(
                token_log_probs, dtype=torch.float32, device="cpu").sum()
            nll_sum += neg_log_likelihood
            n_tokens += len(token_log_probs)
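        # Perplexity is the exponential of the mean negative log-likelihood
        # per predicted token: PPL = exp(nll_sum / n_tokens).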
        vllm_ppl = float(torch.exp(nll_sum / n_tokens))
        vllm_dtype = model_config.dtype

    # To speed up the test, a precomputed Transformers ppl score
    # (model_info.hf_ppl) can be used as a constant reference instead of
    # rerunning the HF model.
    if model_info.hf_ppl is None:
        with hf_runner(
                model_info.name,
                dtype=model_info.hf_dtype,
        ) as hf_model:
            nll_sum = torch.tensor(0., dtype=torch.float32, device="cpu")
            n_tokens = 0
            for chunk in chunks:
                inputs = hf_model.wrap_device(
                    {"input_ids": torch.tensor([chunk])})
                input_ids = inputs["input_ids"]
                outputs = hf_model.model(input_ids, labels=input_ids)
                neg_log_likelihood = outputs.loss

                neg_log_likelihood = neg_log_likelihood.to(torch.float32).cpu()

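                # outputs.loss is the mean cross-entropy over the shifted
                # labels, i.e. over len(chunk) - 1 predicted tokens, so
                # multiply by that count to recover the summed NLL.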
                num_loss_tokens = len(chunk) - 1
                nll_sum += neg_log_likelihood * num_loss_tokens
                n_tokens += num_loss_tokens

            hf_ppl = float(torch.exp(nll_sum / n_tokens))
            hf_dtype = next(hf_model.model.parameters()).dtype
    else:
        hf_ppl = model_info.hf_ppl
        hf_dtype = "Constant"

    differ = (vllm_ppl - hf_ppl) / hf_ppl
    print("Model:", model_info.name)
    print("vLLM:", vllm_dtype, vllm_ppl)
    print("Transformers:", hf_dtype, hf_ppl)
    print("Difference (%):", differ * 100)

    # The smaller the PPL, the better.
    # We are not concerned if the vLLM PPL is lower than the Transformers
    # PPL, so we only perform a one-sided test.
    assert differ < atol
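

# A minimal usage sketch (an assumption, not part of this file): a concrete
# test module would parametrize over GenerateModelInfo entries and call this
# helper with the pytest-provided hf_runner / vllm_runner fixtures. The
# GenerateModelInfo construction and model name below are illustrative only.
#
# MODELS = [GenerateModelInfo("facebook/opt-125m")]
#
#
# @pytest.mark.parametrize("model_info", MODELS)
# def test_ppl(hf_runner, vllm_runner, model_info: GenerateModelInfo):
#     wikitext_ppl_test(hf_runner, vllm_runner, model_info)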