[SpecDecode][Kernel] Use Flashinfer for Rejection Sampling in Speculative Decoding #7244

Merged
Changes from 1 commit
fix tests and comments
LiuXiaoxuanPKU committed Aug 12, 2024
commit c9f88d956c100c030e42a45addb1cf50d20a5e4b
50 changes: 32 additions & 18 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
size=(batch_size, k),
dtype=torch.int64)
# Verify that sampling succeeds for all cases.
typical_acceptance_sampler(target_probs,
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
@@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
oob_token_ids[0][0] = rogue_token_id

with pytest.raises(AssertionError):
typical_acceptance_sampler(target_probs,
typical_acceptance_sampler(target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens(
typical_acceptance_sampler = get_acceptance_sampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
dtype=torch.float32)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
@@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int,
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k, vocab_size)
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
# Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
@@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int,
# fallback to the greedy sampling for selecting 1 token for each sequence.
# Verify the same.
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs = target_with_bonus_probs[:, :-1]
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32)
@@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_with_bonus_probs, zero_temperature_token_ids = \
get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
draft_token_ids = zero_temperature_token_ids
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
draft_token_ids = torch.cat(
(draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
output_token_ids = typical_acceptance_sampler(
target_probs,
target_with_bonus_probs,
bonus_token_ids,
draft_probs=None,
draft_token_ids=draft_token_ids)
@@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist(
batch_size, k + 1, vocab_size)
zero_temperature_token_ids = zero_temperature_token_ids[:, :-1]
target_probs[target_probs == 0] = 0.00001
draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size,
zero_temperature_token_ids)
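For orientation, here is a minimal sketch (not part of the diff) of the shape convention these tests now exercise: the target distribution is passed with the bonus position appended, and the sampler slices it off internally. The sizes below are illustrative only.

```python
import torch

batch_size, k, vocab_size = 4, 5, 32  # illustrative sizes

# Target-model probabilities for the k draft positions plus one bonus position.
target_with_bonus_probs = torch.rand(batch_size, k + 1, vocab_size,
                                     dtype=torch.float32)

# The sampler uses the first k positions for accept/reject decisions...
target_probs = target_with_bonus_probs[:, :-1]  # [batch_size, k, vocab_size]
# ...while the last position carries the distribution for the bonus token.
bonus_probs = target_with_bonus_probs[:, -1]    # [batch_size, vocab_size]

assert target_probs.shape == (batch_size, k, vocab_size)
assert bonus_probs.shape == (batch_size, vocab_size)
```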
2 changes: 1 addition & 1 deletion tests/spec_decode/test_spec_decode_worker.py
@@ -230,7 +230,7 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
assert torch.equal(actual.bonus_token_ids,
target_token_ids.reshape(batch_size, k + 1)[:, -1:])
assert torch.equal(
actual.target_probs,
actual.target_with_bonus_probs[:, :-1],
target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1])
assert torch.equal(actual.draft_token_ids, proposal_token_ids)
assert torch.equal(actual.draft_probs, proposal_probs)
97 changes: 70 additions & 27 deletions vllm/model_executor/layers/rejection_sampler.py
@@ -4,11 +4,21 @@
import torch
import torch.jit

from vllm.logger import init_logger
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeStochasticBaseSampler)

logger = init_logger(__name__)

try:
"""
Consider utilizing the FlashInfer rejection sampling kernel initially,
as it employs a dedicated kernel rather than relying on
Torch tensor operations. This design choice helps to fuse operations,
reduce memory I/O, and consequently enhances performance.
"""
from flashinfer.sampling import chain_speculative_sampling
logger.info("Use flashinfer for rejection sampling.")
except ImportError:
chain_speculative_sampling = None
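As a point of reference for what such a fused kernel computes, the core acceptance rule of speculative-decoding rejection sampling can be written in a few Torch operations. This is only an illustrative sketch of the standard algorithm, not the FlashInfer kernel's implementation or this module's exact fallback path.

```python
import torch


def naive_acceptance(target_probs: torch.Tensor,     # [batch, k, vocab]
                     draft_probs: torch.Tensor,      # [batch, k, vocab]
                     draft_token_ids: torch.Tensor,  # [batch, k]
                     uniform_rand: torch.Tensor      # [batch, k], values in [0, 1)
                     ) -> torch.Tensor:
    """Accept draft token x_i with probability min(1, p_target(x_i) / p_draft(x_i))."""
    idx = draft_token_ids.unsqueeze(-1)
    p_target = target_probs.gather(-1, idx).squeeze(-1)  # [batch, k]
    # p_draft > 0 in practice, since each draft token was sampled from draft_probs.
    p_draft = draft_probs.gather(-1, idx).squeeze(-1)    # [batch, k]
    # A draft token is accepted when its uniform sample falls below the ratio.
    return uniform_rand < (p_target / p_draft)
```

A chained kernel additionally samples a recovered token from the adjusted distribution at the first rejected position and appends the bonus token when every draft token is accepted; doing all of this in one kernel avoids several intermediate tensors and the associated memory traffic.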

@@ -55,7 +65,7 @@ def forward(
sequence.

Args:
target_probs_with_bonus_probs: The probability distribution
target_with_bonus_probs: The probability distribution
over token ids given context according to the target model.
shape = [batch_size, num_speculative_tokens + 1, vocab_size]

@@ -82,12 +92,17 @@
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
target_probs = target_with_bonus_probs[:, :-1]
if self._strict_mode:
self._raise_if_incorrect_input(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
self._raise_if_incorrect_input(target_with_bonus_probs,
draft_token_ids, bonus_token_ids,
draft_probs)

batch_size, k, _ = draft_probs.shape

if chain_speculative_sampling:
if batch_size == 0:
return torch.empty(0, k + 1, device=draft_probs.device, dtype=int)

if chain_speculative_sampling is not None:
batch_size, k, _ = draft_probs.shape
uniform_samples = self._create_uniform_samples(
seeded_seqs, batch_size, k, draft_probs.device)
Expand All @@ -97,7 +112,7 @@ def forward(
else:
accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
target_probs,
target_with_bonus_probs[:, :-1],
draft_probs,
draft_token_ids,
seeded_seqs,
@@ -154,28 +169,56 @@ def _create_uniform_samples(self,
torch.Generator]],
batch_size: int, k: int,
device: torch.device) -> torch.Tensor:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.

This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.

Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.

Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if not seeded_seqs:
uniform_rand = torch.rand(batch_size, k + 1, device=device)
else:
uniform_rand = torch.empty(batch_size, k + 1, device=device)

non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(1,
k + 1,
dtype=self.probs_dtype,
device=device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k + 1,
dtype=self.probs_dtype,
device=device)
return torch.rand(batch_size, k + 1, device=device)

uniform_rand = torch.empty(batch_size, k + 1, device=device)

non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(1,
k + 1,
dtype=self.probs_dtype,
device=device,
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k + 1,
dtype=self.probs_dtype,
device=device)
return uniform_rand

def _get_accepted(
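As a small usage illustration of the seeding behavior documented in `_create_uniform_samples` above: rows whose index appears in `seeded_seqs` are drawn from the supplied generator and are therefore reproducible, while the remaining rows are drawn unseeded. The sampler call below is commented out and assumes a hypothetical sampler instance named `sampler`; the standalone lines show the reproducibility property itself.

```python
import torch

# Hypothetical setup: seed only sequence 1 in a batch of 3.
seeded_seqs = {1: torch.Generator(device="cpu").manual_seed(42)}

# uniform = sampler._create_uniform_samples(seeded_seqs, batch_size=3, k=4,
#                                           device=torch.device("cpu"))
# Row 1 of `uniform` is reproducible across runs; rows 0 and 2 are not.

# The same property, shown directly with torch.rand and a fixed-seed generator:
row_a = torch.rand(1, 5, generator=torch.Generator().manual_seed(42))
row_b = torch.rand(1, 5, generator=torch.Generator().manual_seed(42))
assert torch.equal(row_a, row_b)  # identical seeds give identical samples
```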