
Commit 80ca1e6

[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
1 parent 614aa51 commit 80ca1e6

14 files changed: +482 -210 lines changed

tests/samplers/test_typical_acceptance_sampler.py

Lines changed: 64 additions & 32 deletions
@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
     return draft_token_ids


+def get_acceptance_sampler(
+    posterior_threshold: float = 0.03,
+    posterior_alpha: float = 0.9,
+    disable_bonus_tokens: bool = False,
+    strict_mode: bool = False,
+) -> TypicalAcceptanceSampler:
+    """
+    Initializes and returns a TypicalAcceptanceSampler.
+    """
+    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
+                                    disable_bonus_tokens, strict_mode)
+
+
 @pytest.mark.parametrize("k", list(range(1, 6)))
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))

@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     different combinations of k, vocab_size, batch_size and num devices.
     """
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,

@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+    typical_acceptance_sampler(target_probs,
+                               bonus_token_ids,
+                               draft_probs=None,
+                               draft_token_ids=draft_token_ids)


 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
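
For orientation, here is a minimal standalone sketch of the call convention these hunks converge on: the test helper's defaults plus the new keyword arguments draft_probs/draft_token_ids. The import path is an assumption that mirrors the rejection-sampler import seen later in this commit, and running it requires a CUDA device.

# Sketch only: assumes TypicalAcceptanceSampler lives next to the rejection
# sampler under vllm.model_executor.layers (not confirmed in this excerpt)
# and that a CUDA device is available.
import torch
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

torch.set_default_device("cuda:0")
sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                   posterior_alpha=0.9,
                                   disable_bonus_tokens=False,
                                   strict_mode=False)
sampler.init_gpu_tensors(rank=0)

batch_size, k, vocab_size = 2, 3, 30_000
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)

# Typical acceptance does not need the draft distribution, so the updated
# signature passes draft_probs=None explicitly alongside draft_token_ids.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
print(output_token_ids.shape)  # expected: [batch_size, k + 1]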
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,

@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id

     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs, bonus_token_ids,
-                                   draft_token_ids)
+        typical_acceptance_sampler(target_probs,
+                                   bonus_token_ids,
+                                   draft_probs=None,
+                                   draft_token_ids=draft_token_ids)


 @pytest.mark.parametrize("seed", list(range(10)))

@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)

@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # We are using a uniform target probability distribution.
     # For a uniform distribution the entropy is very high and it
     # should lead to all draft tokens being accepted. Verify that.

@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)

-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target probabilities

@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
     # 1.0 tokens in the target distribution we will reject all of them and
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, -1] == -1)
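
The temperature-zero scenario above relies on a one-hot target distribution; the repository builds it with a helper that is not shown in this diff. A rough, illustration-only equivalent:

# Illustrative only; the actual test uses a helper (not shown in this diff)
# to build the temperature-0 target distribution.
import torch

batch_size, k, vocab_size = 5, 3, 30_000
forced_token_ids = torch.randint(0, vocab_size, (batch_size, k))
target_probs = torch.zeros(batch_size, k, vocab_size, dtype=torch.float32)
target_probs.scatter_(-1, forced_token_ids.unsqueeze(-1), 1.0)
# Any draft token that differs from forced_token_ids has posterior probability
# 0 and is rejected; the sampler then falls back to the target model's top
# token for one position, and the remaining positions are marked -1, as the
# assertions in the test above check.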
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # For sequences 0 and 2 set the distribution to a temperature

@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # verify the shape of output_token_ids
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)

@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Create a temperature zero target probability distribution and ensure

@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)

@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])

@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target

@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
                                     high=vocab_size,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 1:-1] == -1)

@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)

@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)

tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 53 additions & 1 deletion
@@ -11,9 +11,15 @@
 numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
 equality. This gives us good coverage of temp=0.

+At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the
+highest probability in the target distribution are accepted. Therefore, we can
+expect greedy equality for the TypicalAcceptanceSampler at temp=0.
+
 For temp>0, we rely on unit tests on the rejection sampler to verify that the
 output distribution is the same with spec decode vs. no spec decode (this would
-be prohibitively expensive to run with a real model).
+be prohibitively expensive to run with a real model). Similarly, for the
+TypicalAcceptance sampler also, we rely on unit tests to validate temp>0
+test cases.

 NOTE: Speculative decoding's distribution equality requires that the measured
 distributions of the target model and proposal model be deterministic given the
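
For background on why temp=0 yields greedy equality here: typical acceptance (as described in the Medusa line of work) accepts a draft token when its probability under the target model clears an entropy-dependent threshold. A hedged sketch of that rule, not vLLM's exact implementation, is:

# Hedged sketch of the typical acceptance rule; constants, clamping, and exact
# entropy handling in vLLM's TypicalAcceptanceSampler may differ.
import torch


def typical_accept_mask(target_probs: torch.Tensor,
                        draft_token_ids: torch.Tensor,
                        posterior_threshold: float = 0.03,
                        posterior_alpha: float = 0.9) -> torch.Tensor:
    """target_probs: [batch, k, vocab]; draft_token_ids: [batch, k]."""
    entropy = -(target_probs * torch.log(target_probs + 1e-5)).sum(dim=-1)
    # Accept if p_target(draft) > min(threshold, alpha * exp(-entropy)).
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy))
    draft_prob = target_probs.gather(
        -1, draft_token_ids.unsqueeze(-1)).squeeze(-1)
    return draft_prob > threshold

At temp=0 the target distribution is (near) one-hot, so the entropy term is roughly zero and only the target model's argmax token clears the threshold, which is the greedy-equality argument made in the docstring above.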
@@ -611,3 +617,49 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        {
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": k,
+            "spec_decoding_acceptance_method": "typical_acceptance_sampler"
+        }
+        # Try a range of common k.
+        for k in [1, 2, 3]
+    ])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        32,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_typical_acceptance_sampling(baseline_llm_generator,
+                                     test_llm_generator, batch_size: int,
+                                     output_len: int):
+    """Verify that speculative decoding produces exact equality to without spec
+    decode with TypicalAcceptanceSampler as the draft token acceptance
+    sampling method.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
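
The new test drives the feature through the end-to-end harness; from user code the equivalent configuration would look roughly like the following sketch, assuming the test kwargs above map one-to-one onto vllm.LLM engine arguments (which is how the e2e harness passes them):

# Rough usage sketch; assumes the kwargs above are plain vllm.LLM engine args.
from vllm import LLM, SamplingParams

llm = LLM(model="JackFram/llama-160m",
          speculative_model="JackFram/llama-68m",
          num_speculative_tokens=3,
          spec_decoding_acceptance_method="typical_acceptance_sampler",
          use_v2_block_manager=True,
          enforce_eager=True)
outputs = llm.generate(["San Francisco is a"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)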

tests/spec_decode/test_dynamic_spec_decode.py

Lines changed: 7 additions & 5 deletions
@@ -3,33 +3,35 @@
 import pytest
 import torch

-from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer

+from .test_utils import mock_spec_decode_sampler
 from .utils import create_batch, mock_worker


 @pytest.mark.parametrize('queue_size', [4])
 @pytest.mark.parametrize('batch_size', [1])
 @pytest.mark.parametrize('k', [1])
+@pytest.mark.parametrize("acceptance_sampler_method",
+                         ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int):
+def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
+                             acceptance_sampler_method: str):
     """Verify that speculative tokens are disabled when the batch size
     exceeds the threshold.
     """
     disable_by_batch_size = 3
-
     draft_worker = mock_worker(cls=MultiStepWorker)
     target_worker = mock_worker()
-    rejection_sampler = MagicMock(spec=RejectionSampler)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(proposer_worker=draft_worker,
                               scorer_worker=target_worker,
-                              rejection_sampler=rejection_sampler,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
                               metrics_collector=metrics_collector,
                               disable_by_batch_size=disable_by_batch_size)
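
mock_spec_decode_sampler comes from the new tests/spec_decode/test_utils helper, which is part of this commit but not shown in this excerpt. A hypothetical stand-in, purely to make the test above readable, could be:

# Hypothetical stand-in; the real helper lives in tests/spec_decode/test_utils
# (added elsewhere in this commit) and may differ.
from unittest.mock import MagicMock

from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)


def mock_spec_decode_sampler(acceptance_sampler_method: str) -> MagicMock:
    """Return a mock for the requested draft-token acceptance sampler."""
    if acceptance_sampler_method == "rejection_sampler":
        return MagicMock(spec=RejectionSampler)
    if acceptance_sampler_method == "typical_acceptance_sampler":
        return MagicMock(spec=TypicalAcceptanceSampler)
    raise ValueError(f"Unknown acceptance sampler: {acceptance_sampler_method}")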
