
Commit 8fe10f3

[V1][spec decode] return logprobs for spec decoding
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
1 parent: f6b3bcb
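For context, a minimal usage sketch of the behavior this commit enables, mirroring the updated test in tests/v1/sample/test_logprobs.py. The prompt string and the print loop are illustrative only; the model pair, speculative_config, and sampling settings are taken from the test's parametrization rather than being prescribed values.

from vllm import LLM, SamplingParams

# Ask for the top-3 logprobs per generated token (same as the test).
sampling_params = SamplingParams(temperature=0, logprobs=3, max_tokens=10)

# Target model plus EAGLE draft model, as parametrized in the test.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
        "max_model_len": 256,
    },
    max_logprobs=5,
    max_model_len=256,
)

# Illustrative prompt; with this change, each generated token carries its
# logprobs even though several tokens may be accepted per spec-decode step.
outputs = llm.generate(["The capital of France is"], sampling_params)
for completion in outputs[0].outputs:
    for per_token_logprobs in completion.logprobs:
        print(per_token_logprobs)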

3 files changed (+57, -47 lines)

tests/v1/sample/test_logprobs.py

Lines changed: 28 additions & 16 deletions
@@ -9,6 +9,7 @@
 import pytest
 import torch
 
+from tests.utils import large_gpu_mark
 from tests.v1.sample.utils import (
     BatchLogprobsComposition,
     BatchLogprobsSpecType,
@@ -18,6 +19,7 @@
 )
 from vllm import SamplingParams
 from vllm.config.model import LogprobsMode
+from vllm.distributed import cleanup_dist_env_and_memory
 
 from ...conftest import HfRunner, VllmRunner
 
@@ -515,11 +517,14 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
 @pytest.mark.parametrize(
     "model_setup",
     [
-        (
-            "eagle",
-            "meta-llama/Llama-3.1-8B-Instruct",
-            "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
-        )
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+            ),
+            marks=large_gpu_mark(min_gb=32),
+        ),
     ],
 )
 def test_spec_decode_logprobs(
@@ -543,17 +548,27 @@ def test_spec_decode_logprobs(
         temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
     )
     method, model_name, spec_model_name = model_setup
+    max_model_len = 256
 
     # Run base LLM.
     ref_llm = LLM(
         model=model_name,
         max_logprobs=5,
-        max_model_len=2048,
+        max_model_len=max_model_len,
         seed=42,
         logprobs_mode=logprobs_mode,
+        gpu_memory_utilization=0.4,
     )
     ref_results = ref_llm.generate([prompt], sampling_params)
+    # Collect logprobs outputs from reference LLM.
+    ref_logprobs = []
+    for output in ref_results[0].outputs:
+        for logprobs in output.logprobs:
+            for token_id in logprobs:
+                ref_logprobs.append(logprobs[token_id])
     del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
     # Run spec decode LLM.
     spec_llm = LLM(
@@ -562,27 +577,24 @@ def test_spec_decode_logprobs(
             "method": method,
             "model": spec_model_name,
             "num_speculative_tokens": 3,
-            "max_model_len": 2048,
+            "max_model_len": max_model_len,
         },
         max_logprobs=5,
-        max_model_len=2048,
+        max_model_len=max_model_len,
         seed=42,
         logprobs_mode=logprobs_mode,
+        gpu_memory_utilization=0.4,
     )
     spec_results = spec_llm.generate([prompt], sampling_params)
-    del spec_llm
-
-    # Collect logprobs outputs from reference and spec decode LLMs.
-    ref_logprobs = []
-    for output in ref_results[0].outputs:
-        for logprobs in output.logprobs:
-            for token_id in logprobs:
-                ref_logprobs.append(logprobs[token_id])
+    # Collect logprobs outputs from spec decode LLM.
     spec_logprobs = []
     for output in spec_results[0].outputs:
         for logprobs in output.logprobs:
             for token_id in logprobs:
                 spec_logprobs.append(logprobs[token_id])
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
 
     # Per-token logprobs are expected to be the same.
     assert len(ref_logprobs) == len(spec_logprobs)

tests/v1/sample/test_rejection_sampler.py

Lines changed: 13 additions & 16 deletions
@@ -21,6 +21,7 @@
 @pytest.fixture
 def rejection_sampler():
     mock_sampler = Mock(spec=Sampler)
+    mock_sampler.logprobs_mode = "raw_logprobs"
     return RejectionSampler(mock_sampler)
 
 
@@ -469,8 +470,8 @@ def estimate_rejection_sampling_pdf(
     Returns:
         Estimated probability distribution of the output tokens.
     """
-    # Mock the sampler that TreeRejectionSampler uses
     mock_sampler = Mock(spec=Sampler)
+    mock_sampler.logprobs_mode = "raw_logprobs"
     rejection_sampler = RejectionSampler(mock_sampler)
     num_tokens = num_samples * k
     # Repeat draft probs num_samples * k times.
@@ -674,19 +675,19 @@ def test_frequency_penalties(rejection_sampler):
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(
         spec_tokens, device=logits.device
     )
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
         [[1, 15, -1, -1], [7, -1, -1, -1], [1, 1, 15, -1]],
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_bad_words(rejection_sampler):
@@ -716,14 +717,12 @@ def test_bad_words(rejection_sampler):
     bonus_token_tensor = torch.tensor(
         [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
 
@@ -732,7 +731,7 @@ def test_bad_words(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_allowed_token_ids(rejection_sampler):
@@ -765,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
     bonus_token_tensor = torch.tensor(
         [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
 
@@ -781,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)

vllm/v1/sample/rejection_sampler.py

Lines changed: 16 additions & 15 deletions
@@ -51,9 +51,9 @@ class RejectionSampler(nn.Module):
     def __init__(self, sampler: Sampler):
         super().__init__()
         self.sampler = sampler
-        self.return_processed_logprobs = self.sampler.logprobs_mode.startswith(
-            "processed"
-        )
+        logprobs_mode = self.sampler.logprobs_mode
+        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
+        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
 
     def forward(
         self,
@@ -107,10 +107,9 @@ def forward(
             # Override the logprobs mode to return logits because they are
            # needed later to compute the accepted token logprobs.
            logprobs_mode_override="processed_logits"
-            if self.return_processed_logprobs
+            if self.is_processed_logprobs_mode
            else "raw_logits",
        )
-        bonus_logits = bonus_sampler_output.logprobs_tensors.logprobs
        bonus_token_ids = bonus_sampler_output.sampled_token_ids
 
        # Just like `bonus_logits`, `target_logits` is a new tensor with
@@ -144,16 +143,21 @@ def forward(
             sampling_metadata,
         )
 
-        return SamplerOutput(
-            sampled_token_ids=output_token_ids,
-            logprobs_tensors=self._get_logprobs_tensors(
+        logprobs_tensors = None
+        if sampling_metadata.max_num_logprobs:
+            bonus_logits = bonus_sampler_output.logprobs_tensors.logprobs
+            logprobs_tensors = self._get_logprobs_tensors(
                 sampling_metadata,
                 metadata,
                 logits,
-                target_logits if self.return_processed_logprobs else raw_target_logits,
+                target_logits if self.is_processed_logprobs_mode else raw_target_logits,
                 bonus_logits,
                 output_token_ids,
-            ),
+            )
+
+        return SamplerOutput(
+            sampled_token_ids=output_token_ids,
+            logprobs_tensors=logprobs_tensors,
         )
 
     def _get_logprobs_tensors(
@@ -164,10 +168,7 @@ def _get_logprobs_tensors(
         target_logits: torch.Tensor,
         bonus_logits: torch.Tensor,
         sampled_token_ids: torch.Tensor,
-    ) -> LogprobsTensors | None:
-        if not sampling_metadata.max_num_logprobs:
-            return None
-
+    ) -> LogprobsTensors:
         cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
         cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
 
@@ -190,7 +191,7 @@ def _get_logprobs_tensors(
         accepted_logits = final_logits[accepted_logit_indices]
         accepted_logprobs = (
             accepted_logits
-            if self.logprobs_mode.endswith("logits")
+            if self.is_logits_logprobs_mode
             else self.sampler.compute_logprobs(accepted_logits)
         )
         accepted_tokens = sampled_token_ids[accepted_mask]
