Skip to content

Commit 48bc380

Browse files
[V1][spec decode] return logprobs for spec decoding
1 parent 8db2939 commit 48bc380

File tree

6 files changed

+248
-148
lines changed

6 files changed

+248
-148
lines changed

tests/v1/sample/test_rejection_sampler.py

Lines changed: 71 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from typing import Any, Optional
4+
from unittest.mock import Mock
45

56
import pytest
67
import torch
@@ -11,14 +12,32 @@
1112
from vllm.v1.sample.metadata import SamplingMetadata
1213
from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
1314
RejectionSampler)
15+
from vllm.v1.sample.sampler import Sampler, SamplerOutput
1416
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
1517

1618
DEVICE = current_platform.device_type
1719

1820

1921
@pytest.fixture
2022
def rejection_sampler():
21-
return RejectionSampler()
23+
mock_sampler = Mock(spec=Sampler)
24+
return RejectionSampler(mock_sampler)
25+
26+
27+
def mock_sampler_output(rejection_sampler: RejectionSampler,
28+
bonus_token_ids: torch.Tensor):
29+
rejection_sampler.sampler.return_value = SamplerOutput(
30+
sampled_token_ids=bonus_token_ids, logprobs_tensors=None)
31+
32+
33+
def create_spec_decode_metadata(spec_tokens: list[list[int]],
34+
logits: torch.Tensor) -> SpecDecodeMetadata:
35+
metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
36+
metadata.target_logits_indices = torch.arange(logits.shape[0])
37+
# Output bonus token ids are mocked, so the bonus logit indices should
38+
# be empty.
39+
metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
40+
return metadata
2241

2342

2443
def create_logits_tensor(output_token_ids: list[list[int]],
@@ -83,20 +102,19 @@ def test_perfect_match(rejection_sampler):
83102
logits = create_logits_tensor(output_tokens)
84103
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
85104
device=logits.device)
86-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
87-
device=logits.device)
105+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
88106

107+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
89108
output = rejection_sampler(
90109
spec_decode_metadata,
91110
draft_probs=None,
92-
target_logits=logits,
93-
bonus_token_ids=bonus_token_tensor,
111+
logits=logits,
94112
sampling_metadata=metadata,
95113
)
96114
expected = torch.tensor([[1, 2, 3, 4]],
97115
dtype=torch.int,
98116
device=logits.device)
99-
assert torch.equal(output, expected)
117+
assert torch.equal(output.sampled_token_ids, expected)
100118

101119

102120
def test_early_mismatch(rejection_sampler):
@@ -108,22 +126,21 @@ def test_early_mismatch(rejection_sampler):
108126
logits = create_logits_tensor(output_tokens)
109127
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
110128
device=logits.device)
111-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
112-
device=logits.device)
129+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
113130

131+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
114132
output = rejection_sampler(
115133
spec_decode_metadata,
116134
draft_probs=None,
117-
target_logits=logits,
118-
bonus_token_ids=bonus_token_tensor,
135+
logits=logits,
119136
sampling_metadata=metadata,
120137
)
121138
expected = torch.tensor(
122139
[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
123140
dtype=torch.int,
124141
device=logits.device,
125142
)
126-
assert torch.equal(output, expected)
143+
assert torch.equal(output.sampled_token_ids, expected)
127144

128145

129146
def test_multiple_sequences(rejection_sampler):
@@ -136,20 +153,19 @@ def test_multiple_sequences(rejection_sampler):
136153
logits = create_logits_tensor(output_tokens)
137154
bonus_token_tensor = torch.tensor(
138155
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
139-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
140-
device=logits.device)
156+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
141157

158+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
142159
output = rejection_sampler(
143160
spec_decode_metadata,
144161
draft_probs=None,
145-
target_logits=logits,
146-
bonus_token_ids=bonus_token_tensor,
162+
logits=logits,
147163
sampling_metadata=metadata,
148164
)
149165
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
150166
dtype=torch.int,
151167
device=logits.device)
152-
assert torch.equal(output, expected)
168+
assert torch.equal(output.sampled_token_ids, expected)
153169

154170

155171
def test_single_token_sequence(rejection_sampler):
@@ -161,18 +177,17 @@ def test_single_token_sequence(rejection_sampler):
161177
logits = create_logits_tensor(output_tokens)
162178
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
163179
device=logits.device)
164-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
165-
device=logits.device)
180+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
166181

182+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
167183
output = rejection_sampler(
168184
spec_decode_metadata,
169185
draft_probs=None,
170-
target_logits=logits,
171-
bonus_token_ids=bonus_token_tensor,
186+
logits=logits,
172187
sampling_metadata=metadata,
173188
)
174189
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
175-
assert torch.equal(output, expected)
190+
assert torch.equal(output.sampled_token_ids, expected)
176191

177192

178193
def test_empty_sequence(rejection_sampler):
@@ -184,18 +199,17 @@ def test_empty_sequence(rejection_sampler):
184199
logits = create_logits_tensor(output_tokens)
185200
bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
186201
device=logits.device)
187-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
188-
device=logits.device)
202+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
189203

204+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
190205
output = rejection_sampler(
191206
spec_decode_metadata,
192207
draft_probs=None,
193-
target_logits=logits,
194-
bonus_token_ids=bonus_token_tensor,
208+
logits=logits,
195209
sampling_metadata=metadata,
196210
)
197211
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
198-
assert torch.equal(output, expected)
212+
assert torch.equal(output.sampled_token_ids, expected)
199213

200214

201215
def test_multiple_mismatches(rejection_sampler):
@@ -208,14 +222,13 @@ def test_multiple_mismatches(rejection_sampler):
208222
logits = create_logits_tensor(output_tokens)
209223
bonus_token_tensor = torch.tensor(
210224
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
211-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
212-
device=logits.device)
225+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
213226

227+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
214228
output = rejection_sampler(
215229
spec_decode_metadata,
216230
draft_probs=None,
217-
target_logits=logits,
218-
bonus_token_ids=bonus_token_tensor,
231+
logits=logits,
219232
sampling_metadata=metadata,
220233
)
221234
expected = torch.tensor(
@@ -224,7 +237,7 @@ def test_multiple_mismatches(rejection_sampler):
224237
dtype=torch.int,
225238
device=logits.device,
226239
)
227-
assert torch.equal(output, expected)
240+
assert torch.equal(output.sampled_token_ids, expected)
228241

229242

230243
@pytest.mark.parametrize(
@@ -242,20 +255,19 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
242255
logits = create_logits_tensor(output_tokens)
243256
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
244257
device=logits.device)
245-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
246-
device=logits.device)
258+
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
247259

260+
mock_sampler_output(rejection_sampler, bonus_token_tensor)
248261
output = rejection_sampler(
249262
spec_decode_metadata,
250263
draft_probs=None,
251-
target_logits=logits,
252-
bonus_token_ids=bonus_token_tensor,
264+
logits=logits,
253265
sampling_metadata=metadata,
254266
)
255267
expected_tensor = torch.tensor(expected,
256268
dtype=torch.int,
257269
device=logits.device)
258-
assert torch.equal(output, expected_tensor)
270+
assert torch.equal(output.sampled_token_ids, expected_tensor)
259271

260272

261273
########################### Tests for Random Sampling ###################
@@ -305,17 +317,18 @@ def test_deterministic_when_seeded(
305317
sampling_metadata = create_sampling_metadata(all_greedy=False,
306318
temperature=temperature,
307319
generators=seeded_seqs)
308-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
309-
draft_token_ids.tolist(), device=DEVICE)
320+
spec_decode_metadata = create_spec_decode_metadata(
321+
draft_token_ids.tolist(), target_logits)
322+
323+
mock_sampler_output(rejection_sampler, bonus_token_ids)
310324
rep_result = rejection_sampler(
311325
spec_decode_metadata,
312-
draft_probs=draft_probs,
313-
target_logits=target_logits,
314-
bonus_token_ids=bonus_token_ids,
326+
draft_probs=None,
327+
logits=target_logits,
315328
sampling_metadata=sampling_metadata,
316329
)
317330

318-
results.append(rep_result)
331+
results.append(rep_result.sampled_token_ids)
319332

320333
for i in range(batch_size):
321334
if seeded_mask[i]:
@@ -424,7 +437,9 @@ def estimate_rejection_sampling_pdf(
424437
Returns:
425438
Estimated probability distribution of the output tokens.
426439
"""
427-
rejection_sampler = RejectionSampler()
440+
# Mock the sampler that RejectionSampler uses
441+
mock_sampler = Mock(spec=Sampler)
442+
rejection_sampler = RejectionSampler(mock_sampler)
428443
num_tokens = num_samples * k
429444
# Repeat draft probs num_samples * k times.
430445
draft_probs = draft_probs.reshape(1, 1,
@@ -447,16 +462,17 @@ def estimate_rejection_sampling_pdf(
447462
temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
448463
sampling_metadata = create_sampling_metadata(all_greedy=False,
449464
temperature=temperature)
450-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
451-
draft_token_ids.tolist(), device=bonus_token_ids.device)
452-
output_token_ids = rejection_sampler(
465+
spec_decode_metadata = create_spec_decode_metadata(
466+
draft_token_ids.tolist(), target_logits)
467+
468+
mock_sampler_output(rejection_sampler, bonus_token_ids)
469+
sampler_output = rejection_sampler(
453470
spec_decode_metadata,
454471
draft_probs=draft_probs,
455-
target_logits=target_logits,
456-
bonus_token_ids=bonus_token_ids,
472+
logits=target_logits,
457473
sampling_metadata=sampling_metadata,
458474
)
459-
output_token_ids = output_token_ids[:, :-1].flatten()
475+
output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
460476

461477
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
462478
device="cpu"),
@@ -496,22 +512,20 @@ def _test_masked_logits(
496512
device=DEVICE)
497513

498514
# Create spec decode metadata
499-
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
500-
draft_token_ids,
501-
device=DEVICE,
502-
)
515+
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids,
516+
target_logits)
503517

504518
# Run rejection sampling
505-
output_token_ids = rejection_sampler(
519+
mock_sampler_output(rejection_sampler, bonus_token_ids)
520+
output = rejection_sampler(
506521
spec_decode_metadata,
507522
draft_probs=draft_probs,
508-
target_logits=target_logits,
509-
bonus_token_ids=bonus_token_ids,
523+
logits=target_logits,
510524
sampling_metadata=sampling_metadata,
511525
)
512526

513527
# Remove bonus tokens and reshape
514-
output_token_ids = output_token_ids[:, :-1].flatten().tolist()
528+
output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()
515529

516530
# Check that all sampled tokens are within the unmasked indices.
517531
for i in range(num_tokens):

vllm/v1/engine/logprobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
6666
assert self.logprobs is not None
6767
assert self.cumulative_logprob is not None
6868

69-
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
69+
token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
7070

7171
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
7272
token_ids_lst):

0 commit comments

Comments
 (0)