Introduce speculative decoding with draft models to vLLM #3029
Conversation
(force-pushed from 2585153 to 5529f29)
Hi, this doesn't seem to work with Ray atm. I was trying to benchmark Mixtral 8x7B with Mistral 7B as the draft model.
(force-pushed from 5529f29 to a54ec05)
Ray worker initialization added.
Ah I see, np. I was trying to test with tp 8 on an 8xA10 setup.
Also, I'm getting this error when trying to run

I'm installing like
(force-pushed from a54ec05 to 5dc3688)
Should be fixed now. Ray worker initialization has been added as well (not tested on Ray, but it basically shares the same logic as before).
(force-pushed from 5dc3688 to 95db856)
Just tried this again with tp 8. It freezes after

Something does seem to be loaded into memory, though.

I'm running the benchmark file
(force-pushed from 95db856 to ac70ddd)
Thanks for the feedback. Could you please help confirm that

Thanks!
(force-pushed from ac70ddd to 7142592)
Yes, sure.

Main branch:

On this PR:
Wait, it just worked this time (EDIT: this is lookahead 1, so not using spec decode).

Will test some more combinations.
Thanks for reporting! I have just added some of my previous experiment results to the PR description.
When I use

It freezes. Using lookahead 1 works fine.
Have you tested if it works when using the
In your experience, how long does the benchmark take?
The draft model and the target model share the same code path, so there shouldn't be many differences in model loading (except for the model size).
I think there's something wrong with my setup. It stops after loading one model.
I cannot find an 8-GPU machine to run the un-quantized Mixtral 8x7B and Mistral 7B models, but I can confirm that using the GPTQ version
(force-pushed from 9385c9a to 1642fa3)
(force-pushed from 1642fa3 to 6c87e92)
    if input_metadata.use_flash_attn:
        # see also: https://github.com/Dao-AILab/flash-attention/commit/54e80a3829c6d2337570d01e78ebd9529c02d342
        output = flash_attn_with_kvcache(
Thanks for sharing the implementation. Is `flash_attn_with_kvcache` faster than `context_attention_fwd`?
I haven't benchmarked these two kernels, but I have added an end-to-end speculative sampling benchmark in the last part of #3010's pull request description.
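For anyone comparing the two code paths, here is a minimal, self-contained sketch (not code from this PR) of how `flash_attn_with_kvcache` is typically invoked when scoring a few speculative tokens per sequence against an already-populated KV cache; the shapes, cache lengths, and lookahead of 4 are illustrative assumptions.

```python
# Minimal sketch (not this PR's code) of calling flash_attn_with_kvcache to score
# several speculative tokens per sequence against an existing KV cache.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 128
max_cache_len, lookahead = 512, 4  # illustrative sizes

q = torch.randn(batch, lookahead, nheads, headdim, dtype=torch.float16, device="cuda")
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)
k_cache = torch.zeros(batch, max_cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
# Number of tokens already stored in the cache for each sequence.
cache_seqlens = torch.tensor([100, 200], dtype=torch.int32, device="cuda")

# The kernel appends k_new/v_new into the cache at position cache_seqlens and computes
# causal attention for the `lookahead` query tokens in a single call.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
print(out.shape)  # (batch, lookahead, nheads, headdim)
```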
I met the same issue when tp=4. It freezes after
The command is as follows:

    python3 benchmarks/benchmark_latency.py \
        --model ${model} \
        --draft-model ${draft_model} \
        --temperature 1.0 \
        --parallel-decoding-lookahead 4 \
        --enforce-eager \
        --tensor-parallel-size ${tp} \
        --input-len ${lens} \
        --output-len ${lens} \
        --batch-size ${bs} \
        --num-iters ${iters} 2>&1 | tee -a ${log_file}

The command works fine with
The data is interesting. I think sampling params such as
Another concern: speculative decoding yields a lower TPOT but a higher TTFT. Is it possible to skip speculation for the first generated token? To make it clearer: the first generated token would be produced by the target model only, and speculative decoding would only be applied to the second and following tokens.
I will give it a try.

I just set the temperature for better reproducibility in benchmarks. From my experience,
I think it cannot be achieved, as the autoregressive steps of the draft model require the prompt tokens' KV cache. There exist many other speculative decoding designs, e.g., self-speculative decoding, which doesn't require a draft model. Such decoding strategies are not included in this PR but are compatible with the current design.
The acceptance rate varies between datasets as well.
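For context on why both the temperature and the dataset move the acceptance rate, here is a schematic of the standard speculative-sampling acceptance rule (a generic sketch, not necessarily the exact code in this PR):

```python
# Schematic of the standard speculative-sampling acceptance test: accept a draft token
# with probability min(1, p_target(token) / p_draft(token)). Not this PR's exact code.
import torch

def accept_draft_token(p_target: torch.Tensor, p_draft: torch.Tensor, token: int) -> bool:
    """p_target and p_draft are probability vectors over the vocabulary for one position."""
    ratio = p_target[token] / torch.clamp(p_draft[token], min=1e-10)
    return bool(torch.rand(()) < torch.clamp(ratio, max=1.0))

# As the temperature approaches 0, both distributions collapse toward one-hot, so a draft
# token is accepted only if it matches the target model's argmax; higher temperatures and
# harder datasets reduce the agreement, and therefore the observed acceptance rate.
```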
@sighingnow Thanks for your reply. The main blocker for me is the TTFT, so I have to find some other way to solve it.
If I understand it correctly,
@sighingnow Hi Tao, have you solved the tp>1 error? Speculative decoding is mainly used for models with a large number of parameters, for which tensor parallelism is needed. Looking forward to your good news :)
Automatic prefix caching (#2762) has been merged, but there is no performance increase yet; the performance issue is on the TODO list.
Haven't tried yet; I will try to reproduce it. At the same time, I have noticed that @cadedaniel has submitted the [3/9] PR (#3103) for speculative decoding, and I would like to know what the vLLM community thinks about the development plan for speculative decoding. If it is confirmed that this PR won't be accepted/merged, or will only be partially accepted/merged, refining it may not help much.
Thanks for the work on this -- it's a good PR :). To answer your question: @simon-mo, @LiuXiaoxuanPKU, myself, and the authors of #2607 met a few weeks ago, and @simon-mo wants to go with #2188 first. The key idea is that it refactors some vLLM internals so that different types of speculative decoding are supported, e.g. prompt lookup, RAG acceleration, topk+top1 Eagle/Medusa, typical acceptance, etc. I'll be working full-time this week to finish the PRs. After the correctness tests are in, happy to collaborate on / accept optimizations.
Thank you, @cadedaniel! ❤️ Looking forward to the progress of #2188; this PR actually benefits from it as well. I will see if there is still something worth submitting to vLLM after this series of PRs has been merged.
@sighingnow Thanks for your awesome work! I'm currently working on integrating https://github.com/SafeAILab/EAGLE with vLLM. EAGLE leverages tree attention to score the draft tokens, which applies an attention mask (different from the causal attention mask) to the attention scores. However, the current attention implementations, including flash-attn and flashinfer, seem unable to handle that. Do you have any suggestions? Any guidance you can offer would be greatly appreciated.
For tree attention, one possible way might be forking the tree into several sequences as a batch (their prefixes share the KV cache)? I think hacking into the triton kernel
This might be caused by cache misses and the inefficiency of the triton kernel. Flash-attn may help.
True, and that's what I'm doing with flashinfer. Thanks for your advice :)
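For illustration, here is a small sketch of the kind of extra mask that tree-attention scoring needs (a hypothetical helper, not code from this PR or from EAGLE): each draft token in the token tree attends to the prompt plus its ancestors rather than to a plain causal prefix, which is why standard causal kernels cannot express it directly.

```python
# Hypothetical sketch: build the attention mask over a token tree, where draft token i
# may attend to itself and to its ancestors (parent[i] is i's parent index, -1 for a
# root hanging off the last prompt token). Prompt tokens remain visible to everyone.
import torch

def tree_attention_mask(parent: list[int]) -> torch.Tensor:
    n = len(parent)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        mask[i, i] = True
        j = parent[i]
        while j != -1:          # walk up the tree, marking every ancestor as visible
            mask[i, j] = True
            j = parent[j]
    return mask

# Example: a root token (0) with two children (1, 2); token 3 is a child of token 1.
print(tree_attention_mask([-1, 0, 0, 1]).int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 0, 1, 0],
#         [1, 1, 0, 1]])
```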
"--parallel-decoding-lookahead", | ||
type=int, | ||
default=1, | ||
help="Number of lookahead steps for speculativespeculative decoding.") |
Typo: "speculative" is duplicated.
This PR is (yet another) implementation of speculative decoding on vLLM. Compared with existing efforts (including #2607, #1797, and #1679), this PR:

- […] (the `context_attention_fwd` from the prefix cache PR is enough, and is compatible with further efforts to introduce flash-attn and flashinfer).

(Note that this PR is built upon PR #3007 (GQA fixes for `context_attention_fwd`) and #3010 (introducing flash-attn to vLLM), but does not depend on these two PRs. If these two PRs can be accepted, I would rebase this PR on them; otherwise, submitting the speculative sampling feature in a separate PR also works for me.)

The major change only happens in `llm_engine.py`'s `step()` method and `model_runner.py`'s `_prepare_decode()`/`_prepare_sample()` methods, and should be fairly easy to review.

The major design & implementation can be highlighted as follows:

- `num_gpu_blocks` is computed from `gpu_memory / (draft_block_size + target_block_size)`.
- `LLMEngine`'s `step()` method runs the draft model k times, and then a target model step follows (a toy schematic of this loop is sketched at the end of this description).
- […] the `context_attention_fwd` kernel (originally added for prefix caching) […] the `flash_attn_with_kvcache` kernel.
- […] the `transformers` package for how to decide if a token should be accepted (more specifically, […] `temperature=0.0`).

TODO:

- […] raise `NotImplementedError` for such cases and leave it as a TODO.
- Support initializing draft workers on Ray.

Numbers for a prompt randomly chosen from the dataset, using Llama-2-70B-GPTQ as the target model and Llama-2-7B-GPTQ/TinyLLama-1.1B-Chat-v1.0-GPTQ as the draft model:
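As referenced in the design highlights above, here is a runnable toy schematic of that decode loop; the callables and their behavior are illustrative stand-ins, not the PR's actual code.

```python
# Toy schematic of the speculative decode loop described above: run the draft model for
# `lookahead` autoregressive steps, then verify all proposed tokens with a single
# target-model step. The "models" below are stand-in callables, not this PR's code.
import random

def speculative_step(draft_next_token, target_verify, context, lookahead):
    proposed = []
    for _ in range(lookahead):                          # k cheap autoregressive draft steps
        proposed.append(draft_next_token(context + proposed))
    # One target-model step scores every proposed token at once; a prefix of them is
    # accepted and one extra "bonus" token comes from the target distribution.
    accepted, bonus = target_verify(context, proposed)
    return accepted + [bonus]

# Toy stand-ins for the two models (illustrative only).
draft_next_token = lambda ctx: len(ctx) % 7
def target_verify(ctx, proposed):
    n_accept = random.randint(0, len(proposed))         # pretend some prefix agrees with the target
    return proposed[:n_accept], (len(ctx) + n_accept) % 7

print(speculative_step(draft_next_token, target_verify, context=[1, 2, 3], lookahead=4))
```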