
[Discussion] Will vLLM consider using Speculative Sampling to accelerate LLM decoding? #1171

Closed
gesanqiu opened this issue Sep 25, 2023 · 25 comments


@gesanqiu
Contributor

gesanqiu commented Sep 25, 2023

Sampling is an already known bottleneck of vLLM (see #421 and #670). Last weekend I came across a project named Medusa; in its blog it introduces a simple new decoding method that accelerates LLM generation and reaches good performance. As far as I know, lepton.ai already uses this method.
Adopting the Medusa heads is not difficult, since there is no separate model. But tree attention and the typical acceptance scheme are not a standard process for most LLM inference frameworks and would take a huge effort.

Any advice or comments?

@void-main

I'm actually trying to integrate Medusa into vLLM.

After some quick and dirty experiments, I found that what really makes the integration hard is PagedAttention itself, which, by the way, is one of the core features of vLLM.

The whole vLLM system is built on the assumption that in the decode phase there will be ONE newly generated token which ties to ONE KV cache block for each sequence. With speculative decoding methods like Medusa, this assumption no longer holds.

Though it's hard, I believe it's totally doable. Here are some places we might want to tweak:

  • should take candidates into consideration when you _prepare_inputs in the worker
  • update block tables when you accept one candidate and release other blocks
  • change the semantics of forward, make it return a list of tokens in SequenceOutputs
  • set the medusa tree mask with custom_attention_mask, and take the mask into consideration in single_query_cached_kv_attention (see the sketch below)

Working on a PoC demo (it only supports a single running query); hopefully I can share some new thoughts and findings later.
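Regarding the tree mask item above, here is a rough sketch of how such a mask could be built from Medusa candidate paths. This is an illustration only (the helper name and path representation are made up, not vLLM or Medusa code): each candidate token should only attend to the already-accepted prefix plus its own ancestors in the candidate tree.

```python
import torch

def build_tree_attention_mask(paths: list[list[int]], device="cpu") -> torch.Tensor:
    """Build an additive attention mask for Medusa-style tree candidates.

    `paths` gives, for each candidate token, its path of node ids from the
    tree root, e.g. [[0], [0, 1], [0, 2]].  Token i may attend to token j
    only if j is an ancestor of i (or i itself), i.e. path_j is a prefix of
    path_i.  Returns a [num_tokens, num_tokens] mask with 0 where attention
    is allowed and -inf elsewhere, to be added to the attention scores for
    the candidate positions (the accepted prefix stays fully visible).
    """
    n = len(paths)
    mask = torch.full((n, n), float("-inf"), device=device)
    for i, path_i in enumerate(paths):
        for j, path_j in enumerate(paths):
            if path_i[: len(path_j)] == path_j:
                mask[i, j] = 0.0
    return mask

# A tiny tree: one root candidate with two children.
print(build_tree_attention_mask([[0], [0, 1], [0, 2]]))
```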

@Data-drone

How is your testing going, @void-main?

@void-main

Still working on it. It needs some modifications to PagedAttention, which is at the core of vLLM. @Data-drone

@Data-drone

Is there a branch I can have a look at?

@void-main

Currently the code change is in a private repo at my company. We'd like to release a working version later.

@InkdyeHuang

Currently the code change is in a private repo at my company. We'd like to release a working version later.

When input_token_len = 450 and output_token_len = 150, the prompt (prefill) step time and the generation time are about 1:1, so even when the draft model's accuracy is 96% the speedup is only about 20%; the draft accuracy needs to be above roughly 30% just to cover the additional cost.
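For readers puzzling over that arithmetic, the standard expected-tokens formula from the speculative decoding literature makes the trade-off easier to see. The numbers below are purely illustrative (an idealized model that ignores prefill time and memory-bandwidth effects, not a measurement of vLLM):

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens produced per target-model forward when the draft model
    proposes `gamma` tokens, each accepted with probability `alpha`
    (formula from Leviathan et al., 2023)."""
    if alpha == 1.0:
        return gamma + 1
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def rough_speedup(alpha: float, gamma: int, draft_cost_ratio: float) -> float:
    """Idealized speedup: tokens gained per step divided by the relative cost
    of a step (one target forward plus `gamma` draft forwards)."""
    return expected_tokens_per_step(alpha, gamma) / (1 + gamma * draft_cost_ratio)

# A 96%-accurate draft model costing 10% of the target, proposing 4 tokens:
print(round(rough_speedup(0.96, 4, 0.10), 2))  # ~3.3x in this idealized model
# At ~30% accuracy the gain roughly just covers the drafting overhead:
print(round(rough_speedup(0.30, 4, 0.10), 2))  # ~1.02x
```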

@zhaoyang-star
Contributor

zhaoyang-star commented Nov 16, 2023

Thanks @void-main for sharing the progress on porting Medusa.

I am porting speculative decoding into vLLM. After some quick and dirty code, I also found that the main blocker is PagedAttention. More precisely, in speculative decoding mode, more than ONE token needs to be taken as input when the KV cache already exists. However, PagedAttention only supports two situations:

  • Prefill stage: no KV cache exists yet. Internally this calls the xformers.xops.memory_efficient_attention_forward function.
  • Decoding stage: the KV cache already exists and there is ONLY ONE newly generated token per sequence. Internally this calls the handwritten CUDA kernels paged_attention_v1/v2.

The main blocker for the porting work happens at the decoding stage. I would need to create a new kernel that takes more than one newly generated token plus the existing KV cache as input. The new kernel may differ significantly from the paged_attention_v1/v2 kernels.
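To make the requirement concrete, here is a naive PyTorch reference for what such a kernel has to compute. It is only an illustration: the cache layout and shapes below are simplified assumptions, not vLLM's actual key/value cache layout, and this is nowhere near a performant kernel.

```python
import torch

def multi_token_paged_decode(query, key_cache, value_cache, block_table, context_len):
    """Attention for several new tokens of ONE sequence against a paged KV cache
    (the case paged_attention_v1/v2 does not cover).

    Assumed shapes (simplified for illustration):
      query:       [num_new, num_heads, head_dim]   -- the speculative tokens
      key_cache:   [num_blocks, block_size, num_heads, head_dim]
      value_cache: [num_blocks, block_size, num_heads, head_dim]
      block_table: LongTensor of physical block ids owned by this sequence
      context_len: tokens cached BEFORE the new ones; the new tokens' K/V are
                   assumed to have been written into the cache already.
    """
    num_new, num_heads, head_dim = query.shape
    total_len = context_len + num_new

    # Un-page the cache: gather this sequence's blocks into contiguous K/V.
    k = key_cache[block_table].reshape(-1, num_heads, head_dim)[:total_len]
    v = value_cache[block_table].reshape(-1, num_heads, head_dim)[:total_len]

    q = query.permute(1, 0, 2)                          # [heads, new, dim]
    k = k.permute(1, 0, 2)                              # [heads, total, dim]
    v = v.permute(1, 0, 2)
    scores = q @ k.transpose(-1, -2) / head_dim ** 0.5  # [heads, new, total]

    # Causal mask: new token i may see positions 0 .. context_len + i.
    pos = torch.arange(total_len, device=query.device)
    new_pos = context_len + torch.arange(num_new, device=query.device)
    scores = scores.masked_fill(pos[None, :] > new_pos[:, None], float("-inf"))

    out = torch.softmax(scores, dim=-1) @ v             # [heads, new, dim]
    return out.permute(1, 0, 2)                         # [new, heads, dim]
```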

By now HuggingFace Transformers and llama.cpp have supported speculative decoding or similar methods.

I have found that more and more vLLM users are considering speculative decoding to accelerate LLM inference. I hope we can have a free discussion here; any suggestions are welcome. @WoosukKwon @zhuohan123 @casper-hansen @Yard1

@void-main

@zhaoyang-star

You're right about the decoding stage, but one more thing: when doing Medusa SpS, the m dimension for a single sequence is no longer 1, which means you should consider using Tensor Cores for performance. I've implemented a kernel called PagedFlashAttention to do exactly that.

@void-main

void-main commented Nov 17, 2023

Recent update: I got a working demo version of vLLM + Medusa. For a single sentence, the average accept length is 1.8 (which means you get 1.8 tokens per forward pass on average), and the performance boost is 30%+ compared to vanilla vLLM.

Here's the demo video:
https://github.com/vllm-project/vllm/assets/552990/0f22d79c-262c-4bfb-ab67-4c6ccab80dd5

The left part is vLLM + Medusa, the right part is pure vLLM.

The result is pretty interesting: vLLM + Medusa is slower on the first token, because Medusa doesn't produce any extra tokens on the first forward pass. But Medusa catches up soon and becomes faster on later tokens.

@void-main

Notes on the PagedFlashAttention kernel: as you can tell from the name, it combines PagedAttention with FlashAttention.

PagedAttention only works for decoding stages where you generate 1 token per sequence, so you can use CUDA cores to calculate the attention scores for each sequence. But with tree candidates from Medusa, you need to process ~7-30 candidates per sequence, and sticking to CUDA cores would make it too slow to get any benefit.

But, wait a sec, processing many tokens at once is exactly what FlashAttention does. The only blocker is that vLLM pages the KV caches that need to be used during the generation phase, so I wrote an OpenAI Triton kernel to do that. The code is based on lightllm's TokenAttention, with some page-related loading added.

I'd like to say it's been a pretty fun journey implementing PagedFlashAttention with Triton, but I ran into several issues (triton-lang/triton#2488, triton-lang/triton#2522 and triton-lang/triton#2637) that are not resolved yet, so hopefully we can optimize the kernel later.
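For readers unfamiliar with the "page-related loading" mentioned above, it essentially means translating a token's logical position within its sequence into a physical slot in the paged cache via the block table. A minimal plain-Python sketch of that address computation (the kernel does the same per load, just vectorized):

```python
def logical_to_physical_slot(block_table, block_size, logical_pos):
    """Map a token's logical position in its sequence to the flat index of its
    slot in the paged KV cache, using the sequence's block table."""
    block_idx = logical_pos // block_size      # which logical block of the sequence
    block_offset = logical_pos % block_size    # offset inside that block
    return block_table[block_idx] * block_size + block_offset

# Example: block_size = 16 and the sequence owns physical blocks [7, 3, 42].
table = [7, 3, 42]
print(logical_to_physical_slot(table, 16, 0))   # 112 (block 7, offset 0)
print(logical_to_physical_slot(table, 16, 20))  # 52  (block 3, offset 4)
```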

@void-main

I have found that more and more vLLM users are considering speculative decoding to accelerate LLM inference. I hope we can have a free discussion here; any suggestions are welcome. @WoosukKwon @zhuohan123 @casper-hansen @Yard1

And I totally agree with @zhaoyang-star: vLLM is a great framework, but the whole framework is based on the assumption that each forward pass generates 1 token. Maybe later we should propose an RFC (maybe named VFC?) to extend this.

@AlvL1225

AlvL1225 commented Dec 6, 2023

I have found that more and more vLLM users are considering speculative decoding to accelerate LLM inference. I hope we can have a free discussion here; any suggestions are welcome. @WoosukKwon @zhuohan123 @casper-hansen @Yard1

And I totally agree with @zhaoyang-star: vLLM is a great framework, but the whole framework is based on the assumption that each forward pass generates 1 token. Maybe later we should propose an RFC (maybe named VFC?) to extend this.

Hi, does Medusa or speculative decoding support top-p or top-k sampling?

@zhaoyang-star
Contributor

zhaoyang-star commented Dec 6, 2023

I have found that more and more vLLM users are considering speculative decoding to accelerate LLM inference. I hope we can have a free discussion here; any suggestions are welcome. @WoosukKwon @zhuohan123 @casper-hansen @Yard1

And I totally agree with @zhaoyang-star: vLLM is a great framework, but the whole framework is based on the assumption that each forward pass generates 1 token. Maybe later we should propose an RFC (maybe named VFC?) to extend this.

Hi, does Medusa or speculative decoding support top-p or top-k sampling?

Speculative decoding supports temperature/top-k/top-p.
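For context, the reason arbitrary sampling settings work is the modified rejection scheme from the speculative sampling papers: a draft token is accepted with probability min(1, p_target/q_draft), and on rejection a replacement is drawn from the normalized residual max(p_target - q_draft, 0), which provably reproduces the target distribution. Temperature/top-k/top-p are folded into p_target and q_draft before this step. A minimal NumPy sketch of that rule (not vLLM's implementation):

```python
import numpy as np

def verify_draft_token(p_target, q_draft, draft_token, rng=None):
    """Accept or reject one draft token so the result is distributed exactly
    according to the target distribution p_target.

    p_target, q_draft: probability vectors over the vocabulary (after any
    temperature/top-k/top-p adjustment); draft_token was sampled from q_draft.
    Returns (token, accepted).
    """
    if rng is None:
        rng = np.random.default_rng()
    accept_prob = min(1.0, p_target[draft_token] / q_draft[draft_token])
    if rng.random() < accept_prob:
        return draft_token, True
    # Rejected: resample from the normalized residual distribution.
    residual = np.maximum(p_target - q_draft, 0.0)
    residual /= residual.sum()
    return rng.choice(len(p_target), p=residual), False
```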

@Lvjinhong

Thanks @void-main for sharing the progress on porting Medusa.

I am porting speculative decoding into vLLM. After some quick and dirty code, I also found that the main blocker is PagedAttention. More precisely, in speculative decoding mode, more than ONE token needs to be taken as input when the KV cache already exists. However, PagedAttention only supports two situations:

  • Prefill stage: no KV cache exists yet. Internally this calls the xformers.xops.memory_efficient_attention_forward function.
  • Decoding stage: the KV cache already exists and there is ONLY ONE newly generated token per sequence. Internally this calls the handwritten CUDA kernels paged_attention_v1/v2.

**The main blocker for the porting work happens at the decoding stage.** I would need to create a new kernel that takes more than one newly generated token plus the existing KV cache as input. The new kernel may differ significantly from the paged_attention_v1/v2 kernels.

By now HuggingFace Transformers and llama.cpp have supported speculative decoding or similar methods.

I have found that more and more vLLM users are considering speculative decoding to accelerate LLM inference. I hope we can have a free discussion here; any suggestions are welcome. @WoosukKwon @zhuohan123 @casper-hansen @Yard1

I'm impressed with your excellent work. May I inquire about the current progress? Has speculative decoding been implemented in vLLM?
Additionally, regarding the performance comparison between lightLLM and vLLM, especially with the llama2 70b model, I have some questions. Do you know of any test results that could help me quickly decide on a better framework? Thank you very much.

@RonanKMcGovern
Contributor

As far as I understand, the medusa approach requires training of the added attention heads to be used for look-ahead. This makes it much harder to support a wide variety of models.

Starting with simple n-gram may be best, just to get the feature out and give some speedup.

@Lvjinhong

As far as I understand, the medusa approach requires training of the added attention heads to be used for look-ahead. This makes it much harder to support a wide variety of models.

Starting with simple n-gram may be best, just to get the feature out and give some speedup.

Thank you. Currently, I am only aware that the performance bottleneck of vLLM lies in the decoding stage. Based on your experience, if I, as an individual, want to enhance the performance of vLLM specifically for Llama, are there any feasible solutions to achieve better results?

@RonanKMcGovern
Contributor

@Lvjinhong it's possible to store n-grams - either from the prompt or from the generated text. On every forward pass, you can check whether there are matching n-grams that help you guess the following n tokens. You then include those tokens in the forward pass and can keep all of them if they are correct (or part of them if partly correct). If none are correct, you can at least use those tokens to add to your n-gram list.

I believe this is what TGI does with the --speculate flag.
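A minimal sketch of that prompt-lookup idea (a hypothetical helper, not TGI's or vLLM's code): match the last n tokens of the context against an earlier occurrence and propose the tokens that followed the match as draft tokens.

```python
def propose_ngram_tokens(context, n=3, num_speculative=5):
    """Propose draft tokens by matching the last `n` tokens of `context`
    against an earlier occurrence and copying the tokens that followed it.

    context: list of token ids seen so far (prompt + generated output).
    Returns up to `num_speculative` proposed token ids (possibly empty).
    """
    if len(context) <= n:
        return []
    tail = context[-n:]
    # Search earlier occurrences, most recent first.
    for start in range(len(context) - n - 1, -1, -1):
        if context[start : start + n] == tail:
            follow = context[start + n : start + n + num_speculative]
            if follow:
                return follow
    return []

# The tail [7, 8, 9] appeared earlier, followed by [10, 11]:
print(propose_ngram_tokens([1, 7, 8, 9, 10, 11, 2, 7, 8, 9], n=3, num_speculative=2))  # [10, 11]
```

The proposed tokens are then verified in a single forward pass of the target model, exactly as with a draft model, so wrong guesses only cost the extra compute of that pass.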

@irasin
Contributor

irasin commented Dec 25, 2023

I have found that more and more vLLM users are considering speculative decoding to accelerate LLM inference. I hope we can have a free discussion here; any suggestions are welcome. @WoosukKwon @zhuohan123 @casper-hansen @Yard1

And I totally agree with @zhaoyang-star: vLLM is a great framework, but the whole framework is based on the assumption that each forward pass generates 1 token. Maybe later we should propose an RFC (maybe named VFC?) to extend this.

Hi, does Medusa or speculative decoding support top-p or top-k sampling?

Speculative decoding supports temperature/top-k/top-p.

Hi, I was wondering why speculative decoding supports temperature/top-k/top-p sampling.
Since it uses the argmax result of the draft model, I thought it only supports greedy sampling.

@Moran232

Recent update: I got a working demo version of vLLM + Medusa. For a single sentence, the average accept length is 1.8 (which means you get 1.8 tokens per forward pass on average), and the performance boost is 30%+ compared to vanilla vLLM.

Here's the demo video: https://github.com/vllm-project/vllm/assets/552990/0f22d79c-262c-4bfb-ab67-4c6ccab80dd5

The left part is vLLM + Medusa, the right part is pure vLLM.

The result is pretty interesting: vLLM + Medusa is slower on the first token, because Medusa doesn't produce any extra tokens on the first forward pass. But Medusa catches up soon and becomes faster on later tokens.

Hi, how is the performance for multiple sentences (e.g. batch size = 32/64)?

@void-main

Hi @Moran232, Medusa performs worse at large batch sizes; here are my test results:

[Screenshot: CleanShot 2023-12-28 at 14 57 01]

[Screenshot: CleanShot 2023-12-28 at 14 57 20]

Medusa beats vLLM on small batches (BS < 8), but falls behind on larger batches.

@wasertech

wasertech commented Jan 18, 2024

You might want to know that @cadedaniel is working on a PR that introduces a framework to score and verify draft tokens. That would allow vLLM to benefit from speculative decoding with Medusa, or directly from your target model's n-grams, in #2188 🔥

@dutsc

dutsc commented Mar 20, 2024

@Lvjinhong it's possible to store n-grams - either from the prompt or from the generated text. On every forward pass, you can check whether there are matching n-grams that help you guess the following n tokens. You then include those tokens in the forward pass and can keep all of them if they are correct (or part of them if partly correct). If none are correct, you can at least use those tokens to add to your n-gram list.

I believe this is what TGI does with the --speculate flag.

Hello @RonanKMcGovern! Do the stored n-grams you mentioned mean the same thing as the n-grams in Lookahead Decoding? In other words, assuming n = 3 and abc, def, xyz are stored in the n-gram list: when my prompt is forwarded and I get 123789a, can I directly guess the output as 123789abc based on the 3-gram list?

@RonanKMcGovern
Contributor

RonanKMcGovern commented Mar 20, 2024 via email

@dutsc

dutsc commented Mar 20, 2024

I believe the TGI implementation does not use the Jacobi method. It is a plain build of n-grams using both the prompt AND the tokens generated to date. I have to admit I don't grasp exactly how they build the n-grams. It may be simple pattern matching of past sequences to the latest token.


Thank you very much for your answer!

@hmellor
Collaborator

hmellor commented May 18, 2024

Feature request for this: #1023

@hmellor hmellor closed this as completed May 18, 2024