[V1][Spec Decode] Optimize N-gram matching with Numba #13365

Merged · 17 commits · Feb 18, 2025
Changes from all commits
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -1,6 +1,7 @@
 psutil
 sentencepiece # Required for LLaMA tokenizer.
 numpy < 2.0.0
+numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding.
Member

Shouldn't this be in requirements-cuda.txt rather than common?

Collaborator Author

Oh, I'm OK with either; I just thought it would eventually be used by others as well. Please feel free to submit a PR to move it to requirements-cuda.txt and probably requirements-rocm.txt.

 requests >= 2.26.0
 tqdm
 blake3
113 changes: 55 additions & 58 deletions vllm/v1/spec_decode/ngram_proposer.py
@@ -1,14 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
+from numba import jit
 
 
 class NgramProposer:
 
-    def __init__(self):
-        pass
-
     def propose(
         self,
         context_token_ids: np.ndarray,
Expand All @@ -21,7 +19,7 @@ def propose(
         that match.
 
         Args:
-            context_token_ids: List of token IDs representing the
+            context_token_ids: Numpy array of token IDs representing the
                 context sequence.
             n: Length of the n-gram to match.
             k: Number of tokens follow the match. If there are less
@@ -41,66 +39,65 @@ def propose(
             followed that pattern. Here we will return [4,2,3] because
             we only have three tokens after the match.
         """
-        # TODO: Use c++ to implement the _find_subarray_kmp to
-        # improve the efficiency
-        return self._find_subarray_kmp(context_token_ids, n, k)
-
-    @staticmethod
-    def _kmp_lps_array(pattern: List[int]) -> List[int]:
-        """
-        Build the lps (longest proper prefix which is also suffix)
-        array for the pattern.
-        """
-        lps = [0] * len(pattern)
-        prev_lps = 0  # length of the previous longest prefix suffix
-        i = 1
-
-        while i < len(pattern):
-            if pattern[i] == pattern[prev_lps]:
-                prev_lps += 1
-                lps[i] = prev_lps
-                i += 1
-            else:
-                if prev_lps != 0:
-                    prev_lps = lps[prev_lps - 1]
-                else:
-                    lps[i] = 0
-                    i += 1
-
-        return lps
-
-    @staticmethod
-    def _find_subarray_kmp(
-        context_token_ids: np.ndarray,
-        n: int,
-        k: int,
-    ) -> Optional[np.ndarray]:
-        context_len = context_token_ids.shape[0]
-        assert n > 0
-
-        pattern = context_token_ids[-n:]
-        # Precompute lps array for Y
-        lps = NgramProposer._kmp_lps_array(pattern)
-
-        i = 0
-        j = 0
-        # -n because the last n tokens are used as pattern
-        while i < context_len - n:
-            if context_token_ids[i] == pattern[j]:
-                i += 1
-                j += 1
-
-                # If we have matched the entire Y
-                if j == n:
-                    # Found pattern in context, gather the next K elements
-                    return context_token_ids[i:i + k]
-            else:
-                # Mismatch
-                if j != 0:
-                    # Use the lps array to avoid re-checking elements
-                    j = lps[j - 1]
-                else:
-                    i += 1
-
-        # Y not found
-        return None
+        return _find_subarray_kmp(context_token_ids, n, k)
+
+
+@jit(nopython=True)
+def _kmp_lps_array(pattern: np.ndarray) -> np.ndarray:
+    """
+    Build the lps (longest proper prefix which is also suffix)
+    array for the pattern.
+    """
+    lps = np.zeros(len(pattern), dtype=np.int32)
+    prev_lps = 0  # length of the previous longest prefix suffix
+    i = 1
+
+    while i < len(pattern):
+        if pattern[i] == pattern[prev_lps]:
+            prev_lps += 1
+            lps[i] = prev_lps
+            i += 1
+        else:
+            if prev_lps != 0:
+                prev_lps = lps[prev_lps - 1]
+            else:
+                lps[i] = 0
+                i += 1
+
+    return lps
+
+
+@jit(nopython=True)
+def _find_subarray_kmp(
+    context_token_ids: np.ndarray,
+    n: int,
+    k: int,
+) -> Optional[np.ndarray]:
+    context_len = context_token_ids.shape[0]
+    assert n > 0
+
+    pattern = context_token_ids[-n:]
+    # Precompute lps array for Y
+    lps = _kmp_lps_array(pattern)
+
+    i = 0
+    j = 0
+    # -n because the last n tokens are used as pattern
+    while i < context_len - n:
+        if context_token_ids[i] == pattern[j]:
+            i += 1
+            j += 1
+
+            # If we have matched the entire Y
+            if j == n:
+                # Found pattern in context, gather the next K elements
+                return context_token_ids[i:i + k]
+        else:
+            # Mismatch
+            if j != 0:
+                # Use the lps array to avoid re-checking elements
+                j = lps[j - 1]
+            else:
+                i += 1
+
+    # Y not found
+    return None
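
For readers skimming the new module, a quick worked run helps. The following is a minimal standalone sketch (not part of the PR), assuming vLLM is installed so that vllm.v1.spec_decode.ngram_proposer is importable; note that the first propose call also pays Numba's one-time JIT compilation cost, which is exactly what the warm-up added to gpu_model_runner.py below avoids during serving.

import numpy as np

from vllm.v1.spec_decode.ngram_proposer import NgramProposer, _kmp_lps_array

proposer = NgramProposer()

# Docstring example: the last n=2 tokens are [2, 3], whose earliest match
# starts at index 1, so the k tokens after that match are proposed.
context = np.array([1, 2, 3, 4, 2, 3], dtype=np.int32)
print(proposer.propose(context, n=2, k=4))  # [4 2 3] (only 3 tokens remain after the match)

# If the last n tokens never occur earlier in the context, nothing is proposed.
print(proposer.propose(np.array([1, 2, 3, 4, 5, 6], dtype=np.int32), n=2, k=4))  # None

# The helper builds the classic KMP "longest proper prefix that is also a
# suffix" table used to skip re-comparisons on a mismatch.
print(_kmp_lps_array(np.array([2, 3, 2, 3, 2], dtype=np.int32)))  # [0 0 1 2 3]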
13 changes: 11 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -121,11 +121,20 @@ def __init__(
         # Set up speculative decoding.
         self.use_spec_decode = False
         if self.speculative_config:
-            self.use_spec_decode = True
 
             # TODO: find a better way to check if we are using ngram.
             assert self.speculative_config.ngram_prompt_lookup_min, \
                 "Currently, only ngram spec decode is supported in V1."
-            self.drafter = NgramProposer()
+            self.use_spec_decode = True
+            if get_pp_group().is_last_rank:
+                self.drafter = NgramProposer()
+                # Trigger Numba JIT compilation for N-gram proposer.
+                # This usually takes less than 1 second.
+                self.drafter.propose(
+                    np.zeros(1024, dtype=np.int32),
+                    self.speculative_config.ngram_prompt_lookup_min,
+                    self.speculative_config.num_speculative_tokens,
+                )
 
         # Request states.
         self.requests: Dict[str, CachedRequestState] = {}
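
The dummy propose call above is only there to force Numba to compile the jitted kernels ahead of time. A standalone sketch of that effect (count_matches below is a hypothetical stand-in, not vLLM code): the first call to a @jit(nopython=True) function compiles it for the given argument types, and subsequent calls reuse the cached machine code, so paying the roughly one-second compile cost in __init__ keeps it off the token-generation path.

import time

import numpy as np
from numba import jit


@jit(nopython=True)
def count_matches(tokens: np.ndarray, value: int) -> int:
    # Hypothetical stand-in for the proposer kernels: count occurrences of value.
    total = 0
    for i in range(tokens.shape[0]):
        if tokens[i] == value:
            total += 1
    return total


tokens = np.zeros(1024, dtype=np.int32)  # same shape/dtype as the warm-up input

t0 = time.perf_counter()
count_matches(tokens, 1)   # first call: triggers compilation
print("first call :", time.perf_counter() - t0)

t0 = time.perf_counter()
count_matches(tokens, 1)   # second call: runs the already-compiled code
print("second call:", time.perf_counter() - t0)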