[V1][Spec Decode] Optimize Rejection Sampler with Triton Kernels #14930
Conversation
Finished rejection_sampler.py; will continue with the other files tonight.
LGTM, thanks!
GREEDY_TEMPERATURE: tl.constexpr = -1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN = 32
Hi @WoosukKwon, is there any limitation that requires MAX_SPEC_LEN to be 32? Can it be larger? Thanks.
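For context, a sketch of why caps like `MAX_SPEC_LEN` are usually compile-time constants in Triton: constructs such as `tl.arange` need constexpr, power-of-two sizes, so the per-request loop/vector width has to be fixed when the kernel is compiled. The kernel below is not the PR's actual kernel; it is a minimal, hypothetical greedy-acceptance kernel (all names other than `MAX_SPEC_LEN` are made up) showing how such a cap bounds per-request work:

```python
import triton
import triton.language as tl

MAX_SPEC_LEN = 32  # must be a compile-time constant (power of two for tl.arange)


@triton.jit
def greedy_accept_kernel(
    draft_ids_ptr,      # int32 [num_reqs, MAX_SPEC_LEN]: padded draft token ids
    target_ids_ptr,     # int32 [num_reqs, MAX_SPEC_LEN]: padded target argmax ids
    num_draft_ptr,      # int32 [num_reqs]: number of valid draft tokens per request
    num_accepted_ptr,   # int32 [num_reqs]: output, length of the accepted prefix
    MAX_SPEC_LEN: tl.constexpr,
):
    req = tl.program_id(0)
    n = tl.load(num_draft_ptr + req)
    offs = tl.arange(0, MAX_SPEC_LEN)  # requires a constexpr, power-of-two size
    mask = offs < n
    draft = tl.load(draft_ids_ptr + req * MAX_SPEC_LEN + offs, mask=mask, other=-1)
    target = tl.load(target_ids_ptr + req * MAX_SPEC_LEN + offs, mask=mask, other=-2)
    # A position counts only if it is valid and the draft token matches the target.
    good = (draft == target) & mask
    # Accepted length = length of the leading run of "good" positions.
    bad_so_far = tl.cumsum(tl.where(good, 0, 1), axis=0)
    accepted = tl.sum((bad_so_far == 0).to(tl.int32), axis=0)
    tl.store(num_accepted_ptr + req, accepted)


# Launched with one program per request, e.g.:
# greedy_accept_kernel[(num_reqs,)](draft_ids, target_ids, num_draft, num_accepted,
#                                   MAX_SPEC_LEN=MAX_SPEC_LEN)
```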
This PR optimizes the rejection sampler in #13933 with custom Triton kernels.
By using the Triton kernels, the PR brings the following benefits:
- Uses the shape `[num_tokens, vocab_size]` for the logits tensors, instead of `[batch_size, max_spec_len, vocab_size]`. This substantially reduces GPU memory usage (see the layout sketch below).
- Avoids expensive PyTorch tensor operations (`cat`, `gather`, etc.).
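To illustrate the first point, here is a small, self-contained sketch (shapes and variable names are made up for illustration, not taken from the PR) comparing the padded `[batch_size, max_spec_len, vocab_size]` layout with the flattened `[num_tokens, vocab_size]` layout, plus the cumulative offsets used to find each request's rows:

```python
import torch

batch_size, max_spec_len, vocab_size = 64, 8, 128_256
# Number of draft tokens each request actually proposed this step.
num_draft_tokens = torch.randint(1, max_spec_len + 1, (batch_size,))
num_tokens = int(num_draft_tokens.sum())

# Padded layout: every request reserves max_spec_len rows, most of them unused.
padded_numel = batch_size * max_spec_len * vocab_size
# Flattened layout: exactly one row per real draft token across the batch.
flat_numel = num_tokens * vocab_size

print(f"padded:    {padded_numel * 2 / 1e9:.2f} GB (fp16)")
print(f"flattened: {flat_numel * 2 / 1e9:.2f} GB (fp16)")

# With the flattened layout, request i's rows live at logits[cu[i]:cu[i + 1]],
# where cu holds cumulative per-request token counts.
cu = torch.cat([torch.zeros(1, dtype=torch.long), num_draft_tokens.cumsum(0)])
logits = torch.empty(num_tokens, vocab_size, dtype=torch.float16)
rows_of_req_3 = logits[int(cu[3]):int(cu[4])]
```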
Performance benchmark: Llama 3.1 8B, ShareGPT, 1xH100, temperature 0.1
SD config: `--speculative-model "[ngram]" --ngram_prompt_lookup_min 5 --ngram-prompt-lookup-max 5 --num_speculative_tokens 3`
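For reference, one plausible way to exercise the same SD configuration from the offline API, assuming the `LLM` constructor accepts the same engine arguments as the CLI flags above (the model id, prompt, and sampling settings are placeholders, not part of the benchmark):

```python
from vllm import LLM, SamplingParams

# Assumed to mirror the CLI flags used in the benchmark above.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
    speculative_model="[ngram]",
    num_speculative_tokens=3,
    ngram_prompt_lookup_min=5,
    ngram_prompt_lookup_max=5,
)

outputs = llm.generate(
    ["Explain speculative decoding in one paragraph."],
    SamplingParams(temperature=0.1, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```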
25% throughput increase compared to main w/o SD, and 18% increase compared to main w/ SD.
Accuracy benchmark: GSM8K, Llama 3.1 8B Instruct, 5 shots