
Commit 6644796

[V1][spec decode] return logprobs for spec decoding (#26060)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
1 parent ff93cc8 commit 6644796
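
With this change, per-token logprobs are returned even for tokens proposed by the draft model and accepted during verification. Below is a minimal usage sketch distilled from the test added in this commit; the model names, speculative settings, and logprobs values are copied from that test and are illustrative, not recommended defaults.

# Minimal sketch based on the new test_spec_decode_logprobs test:
# generate with an EAGLE draft model and read per-token logprobs,
# which this change now populates for speculatively decoded tokens.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
    },
    max_logprobs=5,
)
sampling_params = SamplingParams(temperature=0, logprobs=3, max_tokens=10)

result = llm.generate(["Hello world"], sampling_params)[0]
for step_logprobs in result.outputs[0].logprobs:
    for token_id, lp in step_logprobs.items():
        # Each entry carries the logprob, rank, and decoded token text.
        print(token_id, lp.logprob, lp.rank, lp.decoded_token)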

File tree

8 files changed: +393 −187 lines


tests/v1/sample/test_logprobs.py

Lines changed: 94 additions & 0 deletions
@@ -2,12 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import itertools
+import math
 from collections.abc import Generator
 from typing import get_args

 import pytest
 import torch

+from tests.utils import large_gpu_mark
 from tests.v1.sample.utils import (
     BatchLogprobsComposition,
     BatchLogprobsSpecType,
@@ -17,6 +19,7 @@
 )
 from vllm import SamplingParams
 from vllm.config.model import LogprobsMode
+from vllm.distributed import cleanup_dist_env_and_memory

 from ...conftest import HfRunner, VllmRunner

@@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
     if logprobs_mode in ("raw_logits", "processed_logits"):
         assert positive_values > 0
     del llm
+
+
+@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
+@pytest.mark.parametrize(
+    "model_setup",
+    [
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+            ),
+            marks=large_gpu_mark(min_gb=32),
+        ),
+    ],
+)
+def test_spec_decode_logprobs(
+    logprobs_mode: LogprobsMode,
+    model_setup: tuple[str, str, str],
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Spec decode logprobs should match those of the base model.
+
+    Args:
+        logprobs_mode: logprobs mode.
+        model_setup: Spec decode method, base model name, and
+            draft model name.
+    """
+    from vllm import LLM
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        prompt = "Hello world"
+        sampling_params = SamplingParams(
+            temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
+        )
+        method, model_name, spec_model_name = model_setup
+        max_model_len = 256
+
+        # Run base LLM.
+        ref_llm = LLM(
+            model=model_name,
+            max_logprobs=5,
+            max_model_len=max_model_len,
+            seed=42,
+            logprobs_mode=logprobs_mode,
+            gpu_memory_utilization=0.4,
+        )
+        ref_results = ref_llm.generate([prompt], sampling_params)
+        # Collect logprobs outputs from reference LLM.
+        ref_logprobs = []
+        for output in ref_results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    ref_logprobs.append(logprobs[token_id])
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        # Run spec decode LLM.
+        spec_llm = LLM(
+            model_name,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": max_model_len,
+            },
+            max_logprobs=5,
+            max_model_len=max_model_len,
+            seed=42,
+            logprobs_mode=logprobs_mode,
+            gpu_memory_utilization=0.4,
+        )
+        spec_results = spec_llm.generate([prompt], sampling_params)
+        # Collect logprobs outputs from spec decode LLM.
+        spec_logprobs = []
+        for output in spec_results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    spec_logprobs.append(logprobs[token_id])
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        # Per-token logprobs are expected to be the same.
+        assert len(ref_logprobs) == len(spec_logprobs)
+        for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
+            assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+            assert ref_logprob.rank == spec_logprob.rank
+            assert ref_logprob.decoded_token == spec_logprob.decoded_token

0 commit comments
