
Commit 08577f8

WoosukKwon authored and Mu Huai committed
[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels (vllm-project#14930)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent b07885d · commit 08577f8

File tree

8 files changed: +898, -431 lines changed


tests/v1/sample/test_rejection_sampler.py

Lines changed: 166 additions & 65 deletions
Large diffs are not rendered by default.
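The test diff is not shown here, but the central invariant a rejection-sampler test suite needs to verify is distribution preservation: accepting a draft token with probability min(1, p_target/p_draft) and resampling rejections from the normalized residual max(p_target - p_draft, 0) must reproduce the target distribution exactly. A self-contained sanity check of that property (illustrative only, with invented distributions; this is not the code in the test file):

import torch

torch.manual_seed(0)
target = torch.tensor([0.1, 0.2, 0.3, 0.4])  # p: target-model distribution
draft = torch.tensor([0.4, 0.3, 0.2, 0.1])   # q: draft-model distribution

def sample_one() -> int:
    # Propose from q; accept with probability min(1, p/q).
    tok = int(torch.multinomial(draft, 1))
    if torch.rand(()) * draft[tok] <= target[tok]:
        return tok
    # On rejection, resample from the normalized residual max(p - q, 0).
    residual = torch.clamp(target - draft, min=0)
    return int(torch.multinomial(residual / residual.sum(), 1))

counts = torch.zeros(4)
for _ in range(20_000):
    counts[sample_one()] += 1

# The empirical distribution should match the target, not the draft.
assert torch.allclose(counts / counts.sum(), target, atol=0.02)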

vllm/envs.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
-    VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
     VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0

vllm/v1/outputs.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ class SamplerOutput:
     # [num_reqs, max_num_generated_tokens]
     # Different requests can have different number of generated tokens.
     # All requests are padded to max_num_generated_tokens.
-    # INVALID_TOKEN_ID (-1 by default) is used for padding.
+    # PLACEHOLDER_TOKEN_ID (-1 by default) is used for padding.
     sampled_token_ids: torch.Tensor
     logprobs_tensors: Optional[LogprobsTensors]
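For context, the padding convention the changed comment describes looks like this (token IDs below are invented):

import torch

PLACEHOLDER_TOKEN_ID = -1  # padding value, -1 by default

# Hypothetical batch: request 0 generated 3 tokens, request 1 generated 1,
# so max_num_generated_tokens is 3 and request 1 is padded.
sampled_token_ids = torch.tensor([
    [1024, 90, 7],                                      # request 0
    [512, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # request 1
])

# Recover each request's valid tokens by masking out the padding.
mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
print(sampled_token_ids[1][mask[1]])  # tensor([512])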

vllm/v1/sample/ops/utils.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Union
+
+import torch
+
+
+def compiled_softmax(
+    logits: torch.Tensor,
+    temperature: Union[float, torch.Tensor] = 1.0,
+) -> torch.Tensor:
+    """Faster softmax kernel generated by torch.compile.
+
+    Args:
+        logits: [n, vocab_size]
+        temperature: [n] or float
+    """
+    # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
+    torch._dynamo.mark_dynamic(logits, index=0)
+    if isinstance(temperature, torch.Tensor):
+        torch._dynamo.mark_dynamic(temperature, index=0)
+    return _softmax(logits, temperature)
+
+
+@torch.compile
+def _softmax(
+    logits: torch.Tensor,
+    temperature: Union[float, torch.Tensor],
+) -> torch.Tensor:
+    logits = logits / temperature
+    return torch.softmax(logits, dim=-1, dtype=torch.float32)
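A minimal usage sketch (shapes and values invented): because the first dimension is marked dynamic, calls with different batch sizes reuse the same compiled kernel instead of triggering a recompile per shape.

import torch

from vllm.v1.sample.ops.utils import compiled_softmax

logits = torch.randn(4, 8)  # 4 token positions over a vocab of 8

# First call triggers compilation; the second reuses the compiled kernel
# even though the batch size differs, since dim 0 is marked dynamic.
probs = compiled_softmax(logits, temperature=0.7)
more_probs = compiled_softmax(torch.randn(16, 8), temperature=0.7)

assert probs.dtype == torch.float32
assert torch.allclose(probs.sum(dim=-1), torch.ones(4))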

vllm/v1/sample/rejection_sampler.py

Lines changed: 506 additions & 292 deletions
Large diffs are not rendered by default.
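The full diff is not rendered, but the logic being moved into Triton kernels is the standard speculative-decoding acceptance rule, applied to a variable-length run of draft tokens per request (see the cu_num_draft_tokens bookkeeping in vllm/v1/spec_decode/metadata.py below), which is awkward to express efficiently with padded PyTorch ops. A plain-PyTorch sketch of the per-request acceptance step (an illustration of the algorithm, not the Triton code in this commit):

import torch

def accept_prefix(
    target_probs: torch.Tensor,     # [num_draft, vocab_size], target model
    draft_probs: torch.Tensor,      # [num_draft, vocab_size], draft model
    draft_token_ids: torch.Tensor,  # [num_draft]
) -> torch.Tensor:
    """Return a bool mask over draft tokens: True for the accepted prefix.

    Draft token t is accepted with probability min(1, p_target(t) / p_draft(t));
    everything from the first rejection onward is discarded.
    """
    pos = torch.arange(draft_token_ids.shape[0])
    p_target = target_probs[pos, draft_token_ids]
    p_draft = draft_probs[pos, draft_token_ids]
    u = torch.rand_like(p_target)
    # u * q <= p is equivalent to u <= p / q but avoids dividing by zero.
    accepted = u * p_draft <= p_target
    # Keep only the tokens strictly before the first rejection.
    return (~accepted).int().cumsum(dim=0) == 0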

vllm/v1/spec_decode/metadata.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+
+
+@dataclass
+class SpecDecodeMetadata:
+
+    # [num_tokens]
+    draft_token_ids: torch.Tensor
+    # [batch_size]
+    num_draft_tokens: list[int]
+    # [batch_size]
+    cu_num_draft_tokens: torch.Tensor
+    # [num_tokens]
+    target_logits_indices: torch.Tensor
+    # [batch_size]
+    bonus_logits_indices: torch.Tensor
+    # [num_tokens + batch_size]
+    logits_indices: torch.Tensor
+
+    def __post_init__(self):
+        self.max_spec_len = max(self.num_draft_tokens)
+
+    @classmethod
+    def make_dummy(
+        cls,
+        draft_token_ids: list[list[int]],
+        device: torch.device,
+    ) -> "SpecDecodeMetadata":
+        batch_size = len(draft_token_ids)
+        num_draft_tokens = [len(ids) for ids in draft_token_ids]
+        flattened_draft_token_ids = sum(draft_token_ids, [])
+        num_tokens = len(flattened_draft_token_ids)
+
+        draft_token_ids_tensor = torch.tensor(flattened_draft_token_ids,
+                                              dtype=torch.int32,
+                                              device=device)
+        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
+        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(
+            device)
+
+        target_logits_indices = torch.zeros(num_tokens,
+                                            dtype=torch.int32,
+                                            device=device)
+        bonus_logits_indices = torch.zeros(batch_size,
+                                           dtype=torch.int32,
+                                           device=device)
+        logits_indices = torch.zeros(num_tokens + batch_size,
+                                     dtype=torch.int32,
+                                     device=device)
+        return cls(
+            draft_token_ids=draft_token_ids_tensor,
+            num_draft_tokens=num_draft_tokens,
+            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
+            target_logits_indices=target_logits_indices,
+            bonus_logits_indices=bonus_logits_indices,
+            logits_indices=logits_indices,
+        )
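A quick sketch of the flattened layout this metadata describes, using a hypothetical two-request batch (make_dummy zero-fills the three index tensors, so only the draft-token fields are meaningful here; token IDs invented):

import torch

from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

# Request 0 proposes 3 draft tokens, request 1 proposes 2.
meta = SpecDecodeMetadata.make_dummy(
    [[11, 12, 13], [21, 22]],
    device=torch.device("cpu"),
)

print(meta.num_draft_tokens)     # [3, 2]
print(meta.max_spec_len)         # 3
print(meta.draft_token_ids)      # tensor([11, 12, 13, 21, 22], dtype=torch.int32)
print(meta.cu_num_draft_tokens)  # tensor([3, 5], dtype=torch.int32)

# Request i's tokens occupy [cu[i-1]:cu[i]] of the flattened tensor.
print(meta.draft_token_ids[3:5])  # tensor([21, 22], dtype=torch.int32), request 1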

vllm/v1/spec_decode/utils.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-from vllm.v1.sample.ops.topk_topp_sampler import random_sample  # noqa
 from vllm.v1.worker.gpu_input_batch import InputBatch