"""Tests for rejection sampling."""
import pytest
from typing import List, Tuple

import torch
import torch.nn.functional as F

from vllm.model_executor.utils import set_random_seed

from vllm.model_executor.layers.rejection_sampler import RejectionSampler


def mock_causal_accepted_tensor(
        k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    """Generate an "accepted" tensor which should yield causally-accepted
    tokens up to the last accepted indices.

    Tokens after last_accepted_indices+1 may also be accepted, although they
    will not be causally accepted.
    """
    batch_size = last_accepted_indices.shape[0]

    accepted = (torch.arange(k).expand(batch_size, k) <=
                last_accepted_indices.unsqueeze(-1).broadcast_to(
                    batch_size, k)).to(device="cuda")

    # Sprinkle accepted values after the contiguous initial accepted values.
    # This replicates the behavior of rejection sampling, which may "accept"
    # a token that cannot be accepted because of causality.
    sprinkle_candidates = (
        torch.arange(k).expand(batch_size, k) >
        last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
    sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
    accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
    return accepted


@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
    "which_tokens_accepted",
    ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int):
    """Verify that the output has the correct format given a predetermined
    accepted matrix.
    """
    set_random_seed(seed)

    batch_size = 10
    k = 5
    vocab_size = 3000

    if which_tokens_accepted == "all_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "no_tokens_accepted":
        accepted = mock_causal_accepted_tensor(
            k, -torch.ones((batch_size, ), dtype=torch.long))
    elif which_tokens_accepted == "some_tokens_accepted":
        last_accepted_indices = torch.randint(low=-1,
                                              high=k,
                                              size=(batch_size, ))
        accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
    else:
        raise AssertionError()

    recovered_token_ids = torch.randint(low=0,
                                        high=vocab_size,
                                        size=(batch_size, k),
                                        dtype=torch.int64,
                                        device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")

    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)
    output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
        accepted,
        recovered_token_ids,
        draft_token_ids,
        bonus_token_ids,
    )

    if which_tokens_accepted == "all_tokens_accepted":
        # Expect all tokens to be equal to draft tokens.
        assert torch.equal(output_token_ids[:, :-1], draft_token_ids)

        # Expect all bonus tokens to be included.
        assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
    elif which_tokens_accepted == "no_tokens_accepted":
        # Expect first token to be equal to recovered tokens.
        assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])

        # Expect everything else to be -1.
        assert torch.equal(output_token_ids[:, 1:],
                           torch.ones_like(output_token_ids[:, 1:]) * -1)
    elif which_tokens_accepted == "some_tokens_accepted":
        recovered_plus_bonus = torch.cat(
            (recovered_token_ids, bonus_token_ids), dim=-1)
        # Assert first rejected token is a recovered token or bonus token.
        assert torch.equal(
            recovered_plus_bonus[torch.arange(0, batch_size),
                                 last_accepted_indices + 1],
            output_token_ids[torch.arange(0, batch_size),
                             last_accepted_indices + 1])

        # Assert every subsequent token is -1.
        subsequent_mask = torch.arange(0, k + 1).expand(
            batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
        assert torch.all(output_token_ids[subsequent_mask] == -1)


@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
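    # Smoke test: the sampler should run without raising across a range of
    # k, vocab_size, and batch_size combinations.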
    rejection_sampler = RejectionSampler()
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size,
                             k,
                             vocab_size,
                             dtype=torch.float32,
                             device="cuda")
    target_probs = torch.rand(batch_size,
                              k,
                              vocab_size,
                              dtype=torch.float32,
                              device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")

    rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                      draft_token_ids)


@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
                         ["bonus_token_ids", "draft_token_ids"])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
                               which_token_ids: str):
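    # With strict_mode=True, an out-of-vocabulary token id should trip the
    # sampler's input validation and surface as an AssertionError.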
    k = 3
    batch_size = 5
    vocab_size = 30_000

    rejection_sampler = RejectionSampler(strict_mode=True)
    rejection_sampler.init_gpu_tensors(rank=0)

    draft_probs = torch.rand(batch_size,
                             k,
                             vocab_size,
                             dtype=torch.float32,
                             device="cuda")
    target_probs = torch.rand(batch_size,
                              k,
                              vocab_size,
                              dtype=torch.float32,
                              device="cuda")
    bonus_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64,
                                    device="cuda")
    draft_token_ids = torch.randint(low=0,
                                    high=vocab_size,
                                    size=(batch_size, k),
                                    dtype=torch.int64,
                                    device="cuda")

    oob_token_ids = None
    if which_token_ids == "bonus_token_ids":
        oob_token_ids = bonus_token_ids
    elif which_token_ids == "draft_token_ids":
        oob_token_ids = draft_token_ids
    else:
        raise AssertionError()

    if above_or_below_vocab_range == "above":
        rogue_token_id = vocab_size + 1
    elif above_or_below_vocab_range == "below":
        rogue_token_id = -1
    else:
        raise AssertionError()

    oob_token_ids[0][0] = rogue_token_id

    with pytest.raises(AssertionError):
        rejection_sampler(target_probs, bonus_token_ids, draft_probs,
                          draft_token_ids)


@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
        seed: int, draft_and_target_probs_equal: bool):
    """Verify rejection sampling approximates the target distribution,
    despite sampling from a potentially distinct draft distribution.

    This is done by first creating a random target probability
    distribution and a random draft probability distribution. We then
    sample token ids from the rejection sampler using these draft
    and target distributions. The samples are used to estimate
    the output probability distribution, which we expect to approximate
    the target distribution.

    A basic distance metric is used to determine similarity between
    distributions.

    We expect that as the number of samples increases, the distance
    between the observed distribution and the target distribution
    decreases. To measure this, we compare the distance of the observed
    distribution against both the target distribution and a uniform
    random distribution. We expect the distance between the observed
    distribution and the target distribution to improve much more than
    the distance between the observed distribution and the random
    distribution.

    When draft_and_target_probs_equal=True, the draft and target
    probabilities are exactly equal. Rejection sampling should
    still work without any NaNs or exceptions.
    """
    set_random_seed(seed)

    helper = _CorrectnessTestHelper(
        vocab_size=10,
        rejection_sampler=RejectionSampler(),
    )

    draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
        draft_and_target_probs_equal)

    sample_sizes = [10, 100, 1_000, 10_000, 100_000]
    distance_wrt_reference = []
    distance_wrt_target = []

    for num_samples in sample_sizes:
        (reference_vs_rejsample_dist,
         target_vs_rejsample_dist) = helper.run_and_compare_distributions(
             draft_probs,
             target_probs,
             reference_probs,
             num_samples,
         )

        distance_wrt_reference.append(reference_vs_rejsample_dist)
        distance_wrt_target.append(target_vs_rejsample_dist)

        relative_change_in_distance_wrt_target = get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
              f"{reference_vs_rejsample_dist=:.05f}")
        print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
              f"{relative_change_in_distance_wrt_reference=:.02f}")

    relative_change_in_distance_wrt_target = get_ratio_first_to_last(
        distance_wrt_target)
    relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
        distance_wrt_reference)

    expected_improvement_multiplier = 20
    assert (relative_change_in_distance_wrt_target >
            relative_change_in_distance_wrt_reference *
            expected_improvement_multiplier)


def get_ratio_first_to_last(elements: List[float]) -> float:
    return elements[0] / elements[-1]

class _CorrectnessTestHelper:
    """Class that packages together logic required for the unit-level
    rejection sampling correctness test.
    """

    def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
        self.rejection_sampler = rejection_sampler
        self.vocab_size = vocab_size
        self.vocab_range = (0, vocab_size)

        self.rejection_sampler.init_gpu_tensors(rank=0)

        # Keep test simple, use k=1
        self.k = 1

        # Bonus tokens not used, but rejection sampler requires
        # correct shape.
        self.num_bonus_tokens = 1

    def generate_probs_for_test(
        self, draft_and_target_probs_equal: bool
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        draft_probs, target_probs = [
            F.softmax(
                torch.rand(self.vocab_size, dtype=torch.float32),
                dim=-1,
            ) for _ in range(2)
        ]

        num_reference_probs = 100
        reference_probs = F.softmax(
            torch.rand(num_reference_probs,
                       self.vocab_size,
                       dtype=torch.float32),
            dim=-1,
        )

        if draft_and_target_probs_equal:
            target_probs = draft_probs.clone()

        return draft_probs, target_probs, reference_probs

    def run_and_compare_distributions(self, draft_probs: torch.Tensor,
                                      target_probs: torch.Tensor,
                                      reference_probs: torch.Tensor,
                                      num_samples: int) -> Tuple[float, float]:
        # Sample using rejection sampling.
        rej_sample_probs = self._estimate_rejection_sampling_pdf(
            draft_probs, target_probs, num_samples)

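        # torch.dist computes the L2 (Euclidean) distance by default; the
        # reference distance below is additionally normalized by the number
        # of random reference distributions.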
        # Average distance from reference probs.
        reference_vs_rejsample_dist = torch.dist(
            reference_probs,
            rej_sample_probs).item() / reference_probs.shape[0]
        target_vs_rejsample_dist = torch.dist(target_probs,
                                              rej_sample_probs).item()

        return reference_vs_rejsample_dist, target_vs_rejsample_dist

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: torch.Tensor,
        target_probs: torch.Tensor,
        num_samples: int,
    ) -> torch.Tensor:
        # Repeat draft probs num_samples times.
        draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
            num_samples, 1, 1)

        # Repeat target probs num_samples * k times.
        # Rejection sampler requires bonus token probs, but they aren't used.
        target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
            num_samples, self.k, 1)

        # Randomly sample draft token ids from draft probs.
        draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
                                            num_samples=1,
                                            replacement=True).reshape(
                                                num_samples, self.k)

        # Bonus tokens not used but required.
        bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
                                      dtype=torch.int64,
                                      device="cuda").repeat(num_samples, 1)

        # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                  bonus_token_ids.to("cuda"),
                                                  draft_probs.to("cuda"),
                                                  draft_token_ids.to("cuda"))

        # Remove bonus tokens
        output_token_ids = output_token_ids[:, :-1].flatten()

        # Estimate probability density function
        hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                   device="cpu"),
                               bins=self.vocab_size,
                               range=self.vocab_range,
                               density=True)

        return hist.hist