
[WIP] Speculative decoding using a draft model #2188

Closed

Conversation

cadedaniel
Collaborator

Speculative decoding

This PR adds speculative decoding to vLLM using the draft-model approach first explored in the original speculative decoding papers.

As a simplified overview, this type of speculative decoding runs a smaller draft model to guess what the larger target model will emit. The target model then verifies the guesses, and emits them if they pass verification. This yields a latency reduction as the verification of many tokens can happen in a single forward pass of the target model.
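
Conceptually, one decoding step looks roughly like the sketch below (simplified pseudo-Python for illustration only; draft_model.generate, target_model.forward, and accept_or_resample are hypothetical names, not APIs added by this PR):

# Propose k cheap speculative tokens with the draft model.
draft_tokens = draft_model.generate(token_ids, num_tokens=k)

# Score all k proposals (plus one bonus position) in a single forward pass
# of the large target model.
target_probs = target_model.forward(token_ids + draft_tokens)

# Accept a prefix of the proposals via rejection sampling; the first rejected
# position is resampled from the target distribution, so outputs still follow
# the target model exactly.
accepted_tokens = accept_or_resample(draft_tokens, target_probs)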

Running this at Anyscale, I see a 30-50% latency reduction depending on the draft model, target model and dataset.

Usage

The usage looks like this:

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",
    speculative_model="JackFram/llama-68m",  # The draft model; must have the same vocabulary as the target model.
    tensor_parallel_size=4,
    speculative_model_uses_tp_1=True,  # If True, run the draft model with TP=1 instead of the target model's TP degree.
    num_speculative_tokens=3,  # The number of speculative tokens to propose and score per step.
)
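
Generation itself is then unchanged from a regular LLM instance; for example, using the standard vLLM generation API (shown only for context):

from vllm import SamplingParams

sampling_params = SamplingParams(temperature=0.8, max_tokens=128)
outputs = llm.generate(["The future of AI is"], sampling_params)
print(outputs[0].outputs[0].text)  # speculative decoding is transparent to callers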

Feature list

  • Any vLLM model can be used as the draft model, as long as it has the same vocabulary as the target model (a quick sanity check is sketched after this list).
  • Rejection sampling guarantees that outputs follow the target model's distribution.
  • The draft model can run at tensor-parallel degree 1 or at the same degree as the target model; running it at TP=1 reduces draft-model latency and contributes to the overall speedup.
  • The vLLM scheduler has been improved to support >1 token per step.
  • Tests for rejection sampling, various workers, scheduler, and e2e speculative decoding.
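
A quick way to sanity-check the same-vocabulary requirement from the first item above (a sketch using Hugging Face tokenizers; not part of this PR):

from transformers import AutoTokenizer

draft_tok = AutoTokenizer.from_pretrained("JackFram/llama-68m")
target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf")

# Token IDs must line up between the two models for verification to be meaningful.
assert draft_tok.get_vocab() == target_tok.get_vocab(), "draft/target vocabularies differ"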

Future work (not implemented)

  • Top-k speculative decoding ("Tree attention") to increase draft hit rates
  • Optimized PagedAttention (multi-query attention) for scoring
  • Beam-sampling + speculative decoding

This PR is marked as draft because nontrivial work is required to get it into a mergeable state.

Guide for reviewers

The following are key files for understanding this speculative decoding implementation:

  • Rejection sampler. Uses rejection sampling to recover the target distribution from samples drawn from the draft distribution (a schematic sketch of the acceptance rule follows this list).
  • Draft target worker. A worker which contains both the draft and target models. This orchestrates the drafting of speculative continuations, scoring, and accepting speculative tokens based on rejection sampling.
  • Multi-step worker. This runs the draft model several times, without invoking the vLLM scheduler each time.
  • Single-TP worker. This allows the DraftTargetWorker to run the draft model with TP=1, while the target model is TP=2, 4 or 8. This reduces latency of the draft model.
  • Key scheduler change. The scheduler and block manager now have the notion of "preallocated slots". This allows them to schedule KV block space sufficient for the worker to run several steps before the next scheduler invocation.
  • Key sequence changes. Sequences now track "processed tokens". A processed token is one that has been processed by the model such that the KV activations have been saved to cache.
  • Performance optimizations. Significant modifications have been made to the sampler so that the draft model can run several steps without CPU synchronization. This plus the SingleTPWorker reduced draft model latency from 5ms to 1.5ms.
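
For context on the rejection sampler above, the standard speculative sampling acceptance rule looks roughly like this, where p_target and p_draft are the two models' probability vectors at the position being verified (a schematic NumPy sketch of the math, not the kernel code in this PR):

import numpy as np

def accept_or_resample(x, p_target, p_draft, rng):
    # Accept draft token x with probability min(1, p_target[x] / p_draft[x]).
    if rng.random() < min(1.0, p_target[x] / p_draft[x]):
        return x
    # Otherwise resample from the renormalized residual max(0, p_target - p_draft),
    # which makes the overall output distribution exactly match the target model.
    residual = np.maximum(p_target - p_draft, 0.0)
    return rng.choice(len(p_target), p=residual / residual.sum())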

@Lvjinhong

Lvjinhong commented Dec 19, 2023

I'm very fortunate to witness such great work from you. How is the current progress? May I use your method to accelerate the llama2 70b on vLLM for now?

@cadedaniel
Collaborator Author

I'm very fortunate to witness such great work from you. How is the current progress? May I use your method to accelerate the llama2 70b on vLLM for now?

hi @Lvjinhong. The current PR requires some work to get into a working state on the public vllm repo. I will start on this later this week but given US holidays I expect to finish early january.

@zhaoyang-star
Contributor

Very glad to see this work. I also have a WIP version of speculative decoding and am looking forward to using this feature.

@cadedaniel
Collaborator Author

Created a plan to break this PR into separate pieces. Pending review from vLLM original authors, I will start on it this week. https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit

@wasertech

wasertech commented Jan 17, 2024

From the little I understood, I think there is a way to avoid needing a separate draft model for your vocabulary by drawing draft tokens directly from the model's own ngrams? It should probably be the subject of a separate PR, but I think this is the way forward for everyone to easily enjoy the benefits of speculative decoding.

What do you think @cadedaniel, do you think it's as straightforward as @simon-mo puts it?

I would imagine after this PR, ngram support should be very straightforward.

@cadedaniel
Collaborator Author

Yep. This PR builds the framework for scoring and verifying draft tokens, independent of whether they come from a draft model or something like Medusa, or lookup ngrams like prompt lookup or RAG lookup.
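
As a rough illustration of the ngram idea mentioned above (a toy sketch only, not an API in this PR): the "draft" comes from matching the most recent tokens against earlier occurrences in the sequence and copying what followed them.

def prompt_lookup_draft(token_ids, ngram_size=3, num_draft_tokens=5):
    # Match the trailing ngram against earlier positions in the sequence and
    # propose the tokens that followed the most recent match as draft tokens.
    tail = token_ids[-ngram_size:]
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            return token_ids[start + ngram_size:start + ngram_size + num_draft_tokens]
    return []  # no match; fall back to regular decoding for this step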

@wasertech

wasertech commented Jan 17, 2024

I think I understand a bit better now. Thank you again, @cadedaniel, for this incredible work!

@UranusSeven
Contributor

UranusSeven commented Jan 19, 2024

@cadedaniel Hi! Thanks for your great work!

As I understand it, during the prefill stage, we are going to run the draft model once and the target model once. Will this increase the first token latency?

@cadedaniel
Collaborator Author

Yes, the time to first token is a few milliseconds higher for a draft model of size 68m. It can be optimized in future versions, e.g. with Medusa/EAGLE where draft tokens are generated without independent kv cache.

@UranusSeven
Contributor

@cadedaniel Thanks for your reply! I also want to confirm my understanding regarding the decoding step's impact on first token latency:

  • Generating k draft tokens per iteration introduces a delay of k * d milliseconds, where d is the time to run the draft model once.
  • Additionally, the target model's execution time of t milliseconds further contributes to the overall latency.
  • This means the total decoding time is approximately k * d + t milliseconds (see the worked example below); the cost of sampling is ignored.
  • Consequently, newly arriving prefill requests might experience a wait of up to k * d + t milliseconds before execution.
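
(Illustrative numbers only: with the optimized draft latency of d ≈ 1.5 ms quoted in the PR description, a hypothetical t ≈ 30 ms, and k = 3, one decode step takes roughly 3 × 1.5 + 30 ≈ 34.5 ms, which is the upper bound on how long a newly arrived prefill would wait in this scheme.)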

I'm wondering if it helps to break the decoding step into 2 sub-steps, drafting and verifying. In this way, newly arriving prefill requests wait up to k * d milliseconds. And, if the newly arriving prompts are short, they can be batched with the verifying step.

What do you think?

@cadedaniel
Collaborator Author

@UranusSeven could you ask your question in a discussion post? happy to answer there

@SinanAkkoyun

SinanAkkoyun commented Jan 30, 2024

Hi, thanks for the great work! Do you have any tokens-per-second benchmarks? I'd like to use speculative decoding with DeepSeek's 33B and 1B models.

Also, does this PR support a draft model with a different tokenizer than the main model (for example, Llama with a DeepSeek draft model)?

@qizzzh

qizzzh commented Jan 31, 2024

Out of curiosity, does the proposal support separate deployment for the draft and target model? Asking because in production these two likely have different QPS and computing resource requirements.

@zhuohan123 mentioned this pull request Jan 31, 2024

@xunfeng1980 left a comment

vllm-public/vllm/config.py", line 62
    def __init__(self,
                ^
SyntaxError: '(' was never closed

cuda_graph_max_context_len: int = 5000,
cuda_graph_cache_size: int = 10,
flash_style: bool = False,
max_chunked_prefill_len: int = -1,

) ?

vllm-public/vllm/engine/async_llm_engine.py", line 7, in <module>
    from vllm.anyscale.lora.utils import LoRARequest
ModuleNotFoundError: No module named 'vllm.anyscale'

Having the same issue. Any update on this?

@cadedaniel
Collaborator Author

Hi everyone, want to provide a quick update here:

  • Over the last few weeks I've prioritized optimizing Mixtral latency (Optimized fused MoE Kernel, #2913).
  • Now I will focus full-time on getting this merged. I am aiming to finish the merges by February with @LiuXiaoxuanPKU's help reviewing.
  • After the correctness tests are merged, I will accept any optimizations that pass them. I will list out some major optimizations that people can take on (some technical discussions on MQA are already happening; cc @ymwangg @robertgshaw2-neuralmagic).

@binarycrayon

binarycrayon commented Feb 20, 2024 via email

@ymwangg
Contributor

ymwangg commented Feb 21, 2024

Thanks for the update! Looking forward to seeing how to incorporate our work once your PRs are out.

@hchoi-moveworks

Thanks @cadedaniel for truly inspiring work! 🙏

Would speculative decoding work with vLLM's continuous batching as well? Would that be the step "Integrate speculative decoding with LLMEngine" in our proposed design doc?
https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit

@cadedaniel
Collaborator Author

Thanks @cadedaniel for truly inspiring work! 🙏

Would speculative decoding work with vLLM's continuous batching as well? Would that be the step "Integrate speculative decoding with LLMEngine" in our proposed design doc?

https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit

Yep! Once the e2e correctness tests pass you can use it with continuous batching.

@simon-mo mentioned this pull request Apr 4, 2024
@paolovic

Hi @cadedaniel ,
thanks for your great contribution!
Since I'm also very keen on this feature, can I support you somehow?

@cadedaniel
Collaborator Author

Thanks! It would be helpful if you could test it out and add features or optimizations. Once #3951 is merged it will be correct but not yet fast; I'll post more about optimizing it this week.

@paolovic

Thanks! It would be helpful if you could test it out and add features or optimizations. Once #3951 is merged it will be correct but not yet fast; I'll post more about optimizing it this week.

Great, I'll start tomorrow (CET)

@cocoza4

cocoza4 commented Apr 25, 2024

Will the AssertionError: Speculative decoding not yet supported for RayGPU backend error go away if this PR is merged? Ref: #4358

@HimanshuJanbandhu

HimanshuJanbandhu commented May 27, 2024

Hi guys,
First of all, great to see your work on speculative decoding.
However, there is a new paper, Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding (https://arxiv.org/abs/2309.08168), which shows how to create a draft model by skipping layers of the original LLM.
There's another recent paper by Meta on the same idea, LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding (https://arxiv.org/pdf/2404.16710).
This method further reduces memory use, since the same model is used for drafting and verification; the draft model is obtained just by skipping some of the original model's layers.
I would love to see this implemented in vLLM. I am keen on this feature; can I support you somehow?
The official implementation of "Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding" is available at https://github.com/dilab-zju/self-speculative-decoding/tree/main

@w32zhong

w32zhong commented May 31, 2024

Hi @HimanshuJanbandhu, I would like also mention my recent work S3D (https://arxiv.org/abs/2405.20314). It is very similar to your mentioned Self-Speculative Decoding work which is simple to implement and easy to be integrated to existing stacks. But we have achieved better efficiency in general (compared to Self-Spec), our method combines layer-skipping with multiple next-token generation/unmasking. Although ours requires a bit training, but it should be straightforward just like training a Transformer encoder like BERT.

@cadedaniel
Collaborator Author

We welcome a self-speculative implementation!

@paolovic

paolovic commented Jun 4, 2024

Thanks! It would be helpful if you could test it out and add features or optimizations. Once #3951 is merged it will be correct but not yet fast; I'll post more about optimizing it this week.

Great, I'll start tomorrow (CET)

By the way: sorry for volunteering and never getting back to you. Unfortunately, I am and will be really busy until the beginning of August ✌🏻

But thank you very much for your efforts!!!

@skylee-01

Hello, Teacher, it is a great honor to witness your magnificent work. I have been studying your SpecDecode work recently, and I have a question that I hope you can guide me on: why does SpecDecodeWorker inherit from LoraNotSupportedWorkerBase? Why doesn't the current SpecDecode support LoRA? @cadedaniel

@cadedaniel
Collaborator Author

Hi @skylee-01 . See #6912 for the work required to add LoRA + spec decode.

@cadedaniel
Collaborator Author

Thanks for the interest everyone! With #4630 (comment), all the work in this PR has been merged.

@cadedaniel cadedaniel closed this Aug 5, 2024