11"""A layer that samples the next tokens from the model's outputs."""
22import itertools
3+ import warnings
4+ from importlib .util import find_spec
35from math import inf
46from typing import Dict , List , Optional , Tuple
57
1113if HAS_TRITON :
1214 from vllm .model_executor .layers .ops .sample import sample as sample_triton
1315
16+ import vllm .envs as envs
1417from vllm .model_executor .sampling_metadata import (SamplingMetadata ,
1518 SamplingTensors ,
1619 SequenceGroupToSample )
1922 PromptLogprobs , SampleLogprobs , SamplerOutput ,
2023 SequenceOutput )
2124
25+ if envs .VLLM_USE_FLASHINFER_SAMPLER and find_spec ("flashinfer" ):
26+ import flashinfer .sampling
27+ # yapf: disable
28+ from flashinfer .sampling import (
29+ top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling )
30+
31+ # yapf: enable
32+ else :
33+ flashinfer_top_k_top_p_sampling = None
34+
2235# (num_token_ids, num_parent_ids) per sequence group.
2336SampleResultType = List [Tuple [List [int ], List [int ]]]
2437
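
Note on the import gate above: the FlashInfer path is strictly optional. flashinfer_top_k_top_p_sampling stays None unless the VLLM_USE_FLASHINFER_SAMPLER environment variable is truthy and the flashinfer package is importable, and every call site below branches on that None to fall back to the existing Torch code. A minimal standalone sketch of the same optional-dependency pattern (illustrative only; the env-var name and alias here are made up and not part of this change):

    import os
    from importlib.util import find_spec

    # Resolve the optional fast kernel once at import time; leave it as None
    # when the flag is off or the package is missing, so callers can branch on it.
    if os.environ.get("USE_FAST_SAMPLER") and find_spec("flashinfer") is not None:
        from flashinfer.sampling import (
            top_k_top_p_sampling_from_probs as fast_top_k_top_p_sampling)
    else:
        fast_top_k_top_p_sampling = None
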
@@ -123,7 +136,7 @@ def forward(
         logits = logits.to(torch.float)
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
-        if do_top_p_top_k:
+        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
 
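
With this change the Torch-side masking only runs when the FlashInfer kernel was not imported; when FlashInfer is available, top-k/top-p filtering is instead applied inside the fused sampling call added further down. For reference, a self-contained sketch of the standard per-row top-k/top-p filtering that _apply_top_k_top_p performs (a simplified illustration with scalar k and p, not vLLM's implementation, which takes per-row tensors):

    import torch

    def top_k_top_p_mask(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        # Keep only the top_k largest logits per row.
        if top_k > 0:
            kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))
        # Drop a token once the cumulative probability of the tokens before it exceeds top_p.
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        sorted_probs = sorted_logits.softmax(dim=-1)
        exceeds = (sorted_probs.cumsum(dim=-1) - sorted_probs) > top_p
        sorted_logits = sorted_logits.masked_fill(exceeds, float("-inf"))
        # Scatter the filtered logits back to their original columns.
        return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
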
@@ -476,32 +489,65 @@ def _multinomial(
     seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
-        # This is equivalent to torch.repeat_interleaved (which also
-        # forces a GPU<->CPU sync).
-        # This allows us to do sampling with replacement by creating
-        # num_samples copies of each row in the tensor, and then
-        # batch sampling the resulting tensor.
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
         q.exponential_()
     else:
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
-            next_sample_idx = sample_idx + len(seq_ids) * num_samples
-            q[sample_idx:next_sample_idx].exponential_(
-                generator=seq_group.generator)
-            sample_idx = next_sample_idx
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            q[sample_idx:sample_idx +
+              stride].exponential_(generator=seq_group.generator)
+            sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
 
 
+def _top_k_top_p_multinomial_with_flashinfer(
+        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
+        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
+    max_top_k_round = 32
+    if num_samples > 1:
+        probs = probs.repeat_interleave(num_samples, dim=0)
+        top_ks = top_ks.repeat_interleave(num_samples)
+        top_ps = top_ps.repeat_interleave(num_samples)
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
+        probs,
+        uniform_samples,
+        top_ks,
+        top_ps,
+    )
+    if not success.all():
+        warnings.warn("FlashInfer rejection sampling failed, fallback.",
+                      stacklevel=1)
+        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
+        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
+            probs, uniform_samples[0])
+    return batch_next_token_ids.view(-1, num_samples)
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
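
The rewritten _multinomial above relies on the identity behind the Gumbel-max trick: if the E_i are i.i.d. Exponential(1) draws, then argmax_i p_i / E_i selects index i with probability p_i, so dividing the (repeated) probability rows by exponential noise and taking an argmax performs sampling with replacement in one batched kernel. A quick standalone check of that identity (plain PyTorch, illustrative only):

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])

    # argmax of p_i / E_i with E_i ~ Exp(1) picks index i with probability p_i.
    noise = torch.empty(100_000, probs.numel()).exponential_()
    draws = torch.argmax(probs / noise, dim=1)
    print(torch.bincount(draws, minlength=3) / draws.numel())  # roughly [0.10, 0.20, 0.70]
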
@@ -564,18 +610,28 @@ def _sample_with_torch(
                     sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
-            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
-                "seq_groups": seq_groups,
-            }
-
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices], max_best_of_in_batch,
-                **seeded_args)
+            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
+                              seq_groups)
+
+            if flashinfer_top_k_top_p_sampling is not None:
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_best_of_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_best_of_in_batch,
+                    seq_groups=seq_groups_arg)
 
             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
-                sampled_token_ids_tensor[
-                    long_sample_indices] = multinomial_samples[sampling_type]
+                sampled_token_ids_tensor[long_sample_indices] = \
+                    multinomial_samples[sampling_type].to(torch.long)
 
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
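
At this call site the FlashInfer helper receives the per-query top_ks and top_ps slices directly, so the kernel applies the filtering that forward() now skips. A hypothetical standalone invocation with the same shape contract (assumes the flashinfer package is installed and the tensors live on a CUDA device; the sizes and dtypes are only illustrative):

    import torch

    probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)
    top_ks = torch.full((4,), 50, dtype=torch.int32, device="cuda")
    top_ps = torch.full((4,), 0.9, device="cuda")
    # Returns an integer tensor of shape (4, 1); seq_groups=None means no
    # per-request generators, matching the un-seeded RANDOM sampling type.
    next_ids = _top_k_top_p_multinomial_with_flashinfer(
        probs, top_ks, top_ps, num_samples=1, seq_groups=None)
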
@@ -693,9 +749,12 @@ def _sample_with_triton_kernel(
 
 
 def _sample(
-    probs: torch.Tensor, logprobs: torch.Tensor,
-    sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
-    include_gpu_probs_tensor: bool, modify_greedy_probs: bool
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+    include_gpu_probs_tensor: bool,
+    modify_greedy_probs: bool,
 ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
     """
     Args:
@@ -713,6 +772,7 @@ def _sample(
         probs,
         logprobs,
         sampling_metadata,
+        sampling_tensors,
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
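
End to end, behavior is unchanged unless the flag is set: _sample simply threads sampling_tensors through so the FlashInfer branch can read the per-query top-k and top-p values. A hypothetical way to exercise the new path (assumes flashinfer is installed and a CUDA GPU is present; the model name is only an example):

    import os

    # The flag is read when vLLM's sampler module is imported, so set it first.
    os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(temperature=0.8, top_k=50, top_p=0.9)
    out = llm.generate(["The capital of France is"], params)
    print(out[0].outputs[0].text)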