Commit 79d64c4

[Speculative decoding 1/9] Optimized rejection sampler (#2336)
1 parent 74cd5ab commit 79d64c4

File tree

2 files changed: +784 -0 lines changed

Lines changed: 392 additions & 0 deletions

@@ -0,0 +1,392 @@
"""Tests for rejection sampling."""
import pytest
from typing import List, Tuple

import torch
import torch.nn.functional as F

from vllm.model_executor.utils import set_random_seed

from vllm.model_executor.layers.rejection_sampler import RejectionSampler


def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted tokens
    up to last accepted indices.

    Tokens after last_accepted_indices+1 may also be accepted, although they
    will not be causally accepted.
    """
    batch_size = last_accepted_indices.shape[0]

    accepted = (torch.arange(k).expand(batch_size, k) <=
                last_accepted_indices.unsqueeze(-1).broadcast_to(
                    batch_size, k)).to(device="cuda")

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality.
    sprinkle_candidates = (
        torch.arange(k).expand(batch_size, k) >
        last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
    sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted
@pytest.mark.parametrize("seed", list(range(10)))
39+
@pytest.mark.parametrize(
40+
"which_tokens_accepted",
41+
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
42+
@torch.inference_mode()
43+
def test_correct_output_format(which_tokens_accepted: str, seed: int):
44+
"""Verify the output has correct format given predetermined accepted matrix.
45+
"""
46+
set_random_seed(seed)
47+
48+
batch_size = 10
49+
k = 5
50+
vocab_size = 3000
51+
52+
if which_tokens_accepted == "all_tokens_accepted":
53+
accepted = mock_causal_accepted_tensor(
54+
k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
55+
elif which_tokens_accepted == "no_tokens_accepted":
56+
accepted = mock_causal_accepted_tensor(
57+
k, -torch.ones((batch_size, ), dtype=torch.long))
58+
elif which_tokens_accepted == "some_tokens_accepted":
59+
last_accepted_indices = torch.randint(low=-1,
60+
high=k,
61+
size=(batch_size, ))
62+
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
63+
else:
64+
raise AssertionError()
65+
66+
recovered_token_ids = torch.randint(low=0,
67+
high=vocab_size,
68+
size=(batch_size, k),
69+
dtype=torch.int64,
70+
device="cuda")
71+
draft_token_ids = torch.randint(low=0,
72+
high=vocab_size,
73+
size=(batch_size, k),
74+
dtype=torch.int64,
75+
device="cuda")
76+
bonus_token_ids = torch.randint(low=0,
77+
high=vocab_size,
78+
size=(batch_size, 1),
79+
dtype=torch.int64,
80+
device="cuda")
81+
82+
rejection_sampler = RejectionSampler()
83+
rejection_sampler.init_gpu_tensors(rank=0)
84+
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
85+
accepted,
86+
recovered_token_ids,
87+
draft_token_ids,
88+
bonus_token_ids,
89+
)
90+
91+
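    # Illustration of the format checked below (hypothetical values): with
    # k=5, a row with 2 causally-accepted tokens is expected to look like
    #   [draft_0, draft_1, recovered_2, -1, -1, -1]
    # i.e. accepted draft tokens, then the recovered token at the first
    # rejected position, then -1 padding; the bonus token is kept only when
    # all k draft tokens are accepted.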
if which_tokens_accepted == "all_tokens_accepted":
92+
# Expect all tokens to be equal to draft tokens.
93+
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
94+
95+
# Expect all bonus tokens to be included.
96+
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
97+
elif which_tokens_accepted == "no_tokens_accepted":
98+
# Expect first token to be equal to recovered tokens.
99+
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
100+
101+
# Expect everything else to be -1.
102+
assert torch.equal(output_token_ids[:, 1:],
103+
torch.ones_like(output_token_ids[:, 1:]) * -1)
104+
elif which_tokens_accepted == "some_tokens_accepted":
105+
recovered_plus_bonus = torch.cat(
106+
(recovered_token_ids, bonus_token_ids), dim=-1)
107+
# Assert first rejected token is a recovered token or bonus token.
108+
assert torch.equal(
109+
recovered_plus_bonus[torch.arange(0, batch_size),
110+
last_accepted_indices + 1],
111+
output_token_ids[torch.arange(0, batch_size),
112+
last_accepted_indices + 1])
113+
114+
# Assert every subsequent token is -1.
115+
subsequent_mask = torch.arange(0, k + 1).expand(
116+
batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
117+
assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size,
                             k,
                             vocab_size,
                             dtype=torch.float32,
                             device="cuda")
    target_probs = torch.rand(batch_size,
                              k,
                              vocab_size,
                              dtype=torch.float32,
                              device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str):
    k = 3
    batch_size = 5
    vocab_size = 30_000

    rejection_sampler = RejectionSampler(strict_mode=True)
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size,
                             k,
                             vocab_size,
                             dtype=torch.float32,
                             device="cuda")
    target_probs = torch.rand(batch_size,
                              k,
                              vocab_size,
                              dtype=torch.float32,
                              device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")

    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

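    # strict_mode=True is expected to validate that every provided token id
    # lies in [0, vocab_size); both vocab_size + 1 ("above") and -1 ("below")
    # should trip the assertion.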
    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool):
    """Verify rejection sampling approximates target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as we increase the number of samples,
    the distance between the observed distribution and the target
    distribution decreases. To measure this, we compare the distance
    of the observed distribution against both the target distribution
    and a uniform random distribution. We expect the distance between
    the observed distribution and the target distribution to improve
    much more than the distance improvement between the observed
    distribution and the random distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    set_random_seed(seed)

    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference = []
    distance_wrt_target = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: List[float]) -> float:
    return elements[0] / elements[-1]
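
# e.g. get_ratio_first_to_last([0.5, 0.25, 0.05]) == 10.0, i.e. a 10x
# reduction in distance from the first (smallest) to the last (largest)
# sample size.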


class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(rank=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        draft_probs, target_probs = [
            F.softmax(
                torch.rand(self.vocab_size, dtype=torch.float32),
                dim=-1,
            ) for _ in range(2)
        ]

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * k times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist
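
A minimal standalone sketch of how the sampler under test is driven, mirroring the call signature exercised in these tests (shapes and values are illustrative; it assumes a CUDA device and the RejectionSampler/init_gpu_tensors API shown above):

    import torch
    import torch.nn.functional as F

    from vllm.model_executor.layers.rejection_sampler import RejectionSampler

    batch_size, k, vocab_size = 2, 3, 100

    sampler = RejectionSampler()
    sampler.init_gpu_tensors(rank=0)

    # Normalized distributions over the vocab for each speculative position.
    target_probs = F.softmax(
        torch.rand(batch_size, k, vocab_size, device="cuda"), dim=-1)
    draft_probs = F.softmax(
        torch.rand(batch_size, k, vocab_size, device="cuda"), dim=-1)

    # k draft tokens per sequence, plus one bonus token per sequence that is
    # only kept when all k draft tokens are accepted.
    draft_token_ids = torch.randint(
        0, vocab_size, (batch_size, k), dtype=torch.int64, device="cuda")
    bonus_token_ids = torch.randint(
        0, vocab_size, (batch_size, 1), dtype=torch.int64, device="cuda")

    # Output shape is [batch_size, k + 1]; positions after the first
    # rejection are filled with -1.
    output_token_ids = sampler(target_probs, bonus_token_ids, draft_probs,
                               draft_token_ids)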
