Commit f710fb5

[Core] Use flashinfer sampling kernel when available (#7137)

peng1999 and mgoin authored
Co-authored-by: Michael Goin <michael@neuralmagic.com>
1 parent ff7ec82 · commit f710fb5

5 files changed: +130 -28 lines

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 1 deletion
@@ -192,7 +192,9 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
-  command: pytest -v -s samplers
+  commands:
+  - pytest -v -s samplers
+  - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

 - label: LogitsProcessor Test # 5min
   mirror_hardwares: [amd]

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
     python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir

 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
 #################### vLLM installation IMAGE ####################

tests/samplers/test_sampler.py

Lines changed: 36 additions & 1 deletion
@@ -8,6 +8,7 @@
 import torch
 from transformers import GenerationConfig, GenerationMixin

+import vllm.envs as envs
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
@@ -634,7 +635,10 @@ def mock_sample(probs, *args, **kwargs):
         return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
                  for prob in probs], None)

-    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+    # top-k and top-p are only calculated when flashinfer kernel is not available
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
+            patch("vllm.model_executor.layers.sampler."
+                  "flashinfer_top_k_top_p_sampling", None):
         sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

     assert sample_probs is not None
@@ -645,6 +649,37 @@ def mock_sample(probs, *args, **kwargs):
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))


+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_flashinfer_fallback(seed: int, device: str):
+    if not envs.VLLM_USE_FLASHINFER_SAMPLER:
+        pytest.skip("Flashinfer sampler is disabled")
+
+    set_random_seed(seed)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+
+    def failing_flashinfer_sampling(*_args, **_kwargs):
+        return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        n=random.randint(1, 10),
+        seed=random.randint(0, 10000),
+    )
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                sampling_params, device)
+
+    with patch(
+            "vllm.model_executor.layers.sampler."
+            "flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
+        fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                             sampling_params, device)
+
+    assert sampler_output == fallback_sampler_output
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_sampler_repetition_penalty_mixed(device: str):
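
As a rough, hypothetical sketch (not code from this commit) of what test_flashinfer_fallback exercises: patching the module-level kernel handle with a stub whose success mask is all zeros forces the sampler down the renormalize-and-resample fallback path, and the test asserts that this fallback run matches an unpatched run. Assuming vllm is importable, the pattern looks roughly like this:

    from unittest.mock import patch

    import torch

    def always_failing_kernel(probs, uniform_samples, top_ks, top_ps):
        # Mirrors the test's stub: the token ids are unused on failure, and a
        # zero success mask means every row counts as a rejection-sampling
        # failure, so the fallback branch must run.
        return None, torch.zeros(probs.shape[0], dtype=torch.int32,
                                 device=probs.device)

    with patch("vllm.model_executor.layers.sampler."
               "flashinfer_top_k_top_p_sampling", always_failing_kernel):
        ...  # sample here and compare against an unpatched run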

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -30,6 +30,7 @@
 VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
 VLLM_TRACE_FUNCTION: int = 0
 VLLM_ATTENTION_BACKEND: Optional[str] = None
+VLLM_USE_FLASHINFER_SAMPLER: bool = False
 VLLM_PP_LAYER_PARTITION: Optional[str] = None
 VLLM_CPU_KVCACHE_SPACE: int = 0
 VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -256,6 +257,10 @@ def get_default_config_root():
     "VLLM_ATTENTION_BACKEND":
     lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),

+    # If set, vllm will use flashinfer sampler
+    "VLLM_USE_FLASHINFER_SAMPLER":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),

vllm/model_executor/layers/sampler.py

Lines changed: 85 additions & 25 deletions
@@ -1,5 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import warnings
+from importlib.util import find_spec
 from math import inf
 from typing import Dict, List, Optional, Tuple

@@ -11,6 +13,7 @@
 if HAS_TRITON:
     from vllm.model_executor.layers.ops.sample import sample as sample_triton

+import vllm.envs as envs
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)
@@ -19,6 +22,16 @@
                            PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceOutput)

+if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
+    import flashinfer.sampling
+    # yapf: disable
+    from flashinfer.sampling import (
+        top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
+
+    # yapf: enable
+else:
+    flashinfer_top_k_top_p_sampling = None
+
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]

@@ -123,7 +136,7 @@ def forward(
         logits = logits.to(torch.float)
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

-        if do_top_p_top_k:
+        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)

@@ -476,32 +489,65 @@ def _multinomial(
     seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
-        # This is equivalent to torch.repeat_interleaved (which also
-        # forces a GPU<->CPU sync).
-        # This allows us to do sampling with replacement by creating
-        # num_samples copies of each row in the tensor, and then
-        # batch sampling the resulting tensor.
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
         q.exponential_()
     else:
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
-            next_sample_idx = sample_idx + len(seq_ids) * num_samples
-            q[sample_idx:next_sample_idx].exponential_(
-                generator=seq_group.generator)
-            sample_idx = next_sample_idx
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            q[sample_idx:sample_idx +
+              stride].exponential_(generator=seq_group.generator)
+            sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)


+def _top_k_top_p_multinomial_with_flashinfer(
+        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
+        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
+    max_top_k_round = 32
+    if num_samples > 1:
+        probs = probs.repeat_interleave(num_samples, dim=0)
+        top_ks = top_ks.repeat_interleave(num_samples)
+        top_ps = top_ps.repeat_interleave(num_samples)
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
+        probs,
+        uniform_samples,
+        top_ks,
+        top_ps,
+    )
+    if not success.all():
+        warnings.warn("FlashInfer rejection sampling failed, fallback.",
+                      stacklevel=1)
+        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
+        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0])
+    return batch_next_token_ids.view(-1, num_samples)
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
@@ -564,18 +610,28 @@ def _sample_with_torch(
                     sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
-            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
-                "seq_groups": seq_groups,
-            }
-
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices], max_best_of_in_batch,
-                **seeded_args)
+            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
+                              seq_groups)
+
+            if flashinfer_top_k_top_p_sampling is not None:
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_best_of_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_best_of_in_batch,
+                    seq_groups=seq_groups_arg)

             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
-                sampled_token_ids_tensor[
-                    long_sample_indices] = multinomial_samples[sampling_type]
+                sampled_token_ids_tensor[long_sample_indices] = \
+                    multinomial_samples[sampling_type].to(torch.long)

         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(


 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
@@ -713,6 +772,7 @@ def _sample(
         probs,
         logprobs,
         sampling_metadata,
+        sampling_tensors,
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
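
To tie the pieces together, here is a hypothetical end-to-end sketch (not part of this commit; the model name is arbitrary, and it assumes a CUDA GPU plus the flashinfer wheel installed as in the Dockerfile change). With the flag set, Sampler.forward skips _apply_top_k_top_p and delegates top-k/top-p sampling to the flashinfer kernel via _top_k_top_p_multinomial_with_flashinfer:

    import os

    # Set before vllm.model_executor.layers.sampler is imported, since the
    # kernel is bound at import time via find_spec("flashinfer").
    os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=1.0, top_p=0.9, top_k=50, n=2)
    for request_output in llm.generate(["The capital of France is"], params):
        for candidate in request_output.outputs:
            print(candidate.text)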
