94 changes: 94 additions & 0 deletions tests/v1/sample/test_logprobs.py
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
from collections.abc import Generator
from typing import get_args

import pytest
import torch

from tests.utils import large_gpu_mark
from tests.v1.sample.utils import (
    BatchLogprobsComposition,
    BatchLogprobsSpecType,
@@ -17,6 +19,7 @@
)
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory

from ...conftest import HfRunner, VllmRunner

@@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
    if logprobs_mode in ("raw_logits", "processed_logits"):
        assert positive_values > 0
    del llm


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
"model_setup",
[
pytest.param(
(
"eagle",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
),
marks=large_gpu_mark(min_gb=32),
),
],
)
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
monkeypatch: pytest.MonkeyPatch,
):
"""Spec decode logprobs should match those of the base model.
Args:
logprobs_mode: logprobs mode.
model_setup: Spec decode method, base model name, and
draft model name.
"""
from vllm import LLM

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
prompt = "Hello world"
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256

# Run base LLM.
ref_llm = LLM(
model=model_name,
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

# Run spec decode LLM.
spec_llm = LLM(
model_name,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": max_model_len,
},
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
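
For readers who want to try the same comparison outside pytest, here is a minimal standalone sketch of the pattern the new test uses: flatten the per-token Logprob entries from a RequestOutput and compare two runs token by token. The small model, prompt, and tolerance below are illustrative assumptions, not values taken from this PR; only the Logprob fields (logprob, rank, decoded_token) and the nested-loop traversal come from the test above.

import math

from vllm import LLM, SamplingParams


def flatten_logprobs(request_output):
    """Collect every Logprob object in generation order, mirroring the
    nested loops in test_spec_decode_logprobs."""
    flat = []
    for completion in request_output.outputs:
        for per_token_logprobs in completion.logprobs:
            for token_id in per_token_logprobs:
                flat.append(per_token_logprobs[token_id])
    return flat


def compare_runs(ref_output, other_output, abs_tol=1e-3):
    """Assert that two runs produced the same per-token logprobs."""
    ref = flatten_logprobs(ref_output)
    other = flatten_logprobs(other_output)
    assert len(ref) == len(other)
    for a, b in zip(ref, other):
        assert math.isclose(a.logprob, b.logprob, abs_tol=abs_tol)
        assert a.rank == b.rank
        assert a.decoded_token == b.decoded_token


if __name__ == "__main__":
    # Illustrative small model; the test itself uses Llama-3.1-8B + an EAGLE draft.
    llm = LLM(model="facebook/opt-125m", max_logprobs=5)
    params = SamplingParams(temperature=0, logprobs=3, max_tokens=10)
    out = llm.generate(["Hello world"], params)[0]
    # Comparing a run against itself is trivially equal; in the test the second
    # RequestOutput comes from a spec-decode-enabled LLM instead.
    compare_runs(out, out)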