Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
cadedaniel committed Jan 5, 2024
1 parent 9bb7962 commit 876bf2e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions tests/samplers/test_rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int):
size=(batch_size, ))
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
else:
assert False
raise AssertionError()

recovered_token_ids = torch.randint(low=0,
high=vocab_size,
Expand Down Expand Up @@ -190,14 +190,14 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
assert False
raise AssertionError()

if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
assert False
raise AssertionError()

oob_token_ids[0][0] = rogue_token_id

Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def _raise_if_out_of_bounds_vocab(
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
Expand All @@ -385,4 +386,3 @@ def _multinomial(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)

0 comments on commit 876bf2e

Please sign in to comment.