
Speculative Decoding #2607

Closed · wants to merge 2 commits

Conversation

@ymwangg (Contributor) commented Jan 26, 2024

Introduction

Similar to #1797 and #2188, this PR implements speculative decoding. The goal of this PR is to facilitate research in this direction, such as developing new draft-token generation mechanisms, new sampling methods, and optimized CUDA kernels, while the vLLM community settles on the infrastructure.

Example Usage

You can find two example scripts: examples/api_client_spec_dec.py and examples/offline_inference_spec_dec.py.

The fastest way to try out this feature is to run the following commands:

  • Launch the API server with speculative decoding enabled:
python -m vllm.entrypoints.api_server --model lmsys/vicuna-13b-v1.5 --tensor-parallel-size 1 --draft-model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --speculate-length 5
  • Run a query:
python api_client_spec_dec.py --stream --prompt "What's the best way to start learning a new language?" --max-tokens 512
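
For offline inference, examples/offline_inference_spec_dec.py shows the full flow. As a rough sketch, assuming this branch exposes the --draft-model/--speculate-length flags as draft_model/speculate_length engine arguments (assumed names; check the example script for the exact interface), usage would look something like:

# Hypothetical offline usage against this PR's branch; the draft_model and
# speculate_length keyword arguments are assumptions mirroring the CLI flags.
from vllm import LLM, SamplingParams

llm = LLM(
    model="lmsys/vicuna-13b-v1.5",
    draft_model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # small draft model
    speculate_length=5,                                # draft tokens per step
)
params = SamplingParams(temperature=0.0, max_tokens=256)
for out in llm.generate(["What's the best way to start learning a new language?"], params):
    print(out.outputs[0].text)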

Demo

Below is a demo of Llama-70B (using TinyLlama-1.1B as the draft model) running on 4 NVIDIA A100-80GB GPUs.

  • Demo without speculative decoding: (screen recording omitted)

  • Demo with speculative decoding: (screen recording omitted)

Limitations

This feature is still experimental, so use with caution. The following features (not exhaustive) are currently not supported:

  • stop strings
  • beam search
  • scheduler preemption
  • ALiBi slopes
  • prefix caching

Acknowledgements

I'd like to thank @whbldhwj for sharing valuable data and code, especially the first MQA kernel, @YuchengT, @li2haipeng, @KexinFeng, @Vatshank for helpful discussions, @harishneit and @lanking520 for their leadership.

@wongjingping

Hi @ymwangg, thanks a lot for this PR, I learned a few things from it. While trying to test it out, I did a pip install -e . on a fresh container forked from nvidia/cuda:12.1.0-devel-ubuntu22.04 with a few other things, but I hit a weird error: when I tried to import it, I couldn't see any objects or classes from vllm:

import vllm
print(dir(vllm)) # doesn't have the usual LLM, ...

Any idea if I need to build this differently? Curious to know in what environment you got it to work.

@ymwangg (Contributor, Author) commented Jan 29, 2024

@wongjingping this PR doesn't introduce any other dependencies, so it should work as long as vllm works. Are you able to build and run vanilla vllm?

@LiuXiaoxuanPKU (Collaborator)

Hi, I'm curious about the performance of the MQA kernel. Is it possible to get a performance comparison between this kernel and the MQA kernel used by prefix caching (here)?

@ymwangg (Contributor, Author) commented Jan 30, 2024

Hi, I'm curious about the performance of the MQA kernel. Is it possible to get a performance comparison between this kernel and the MQA kernel used by prefix caching (here)?

Sure, I'll run some tests on it. Their performance should be very similar. This PR's MQA kernel makes a few simplifications based on the assumptions that the total query length won't exceed 16 and that the KV cache already contains all tokens. I did find that flash-attention's flash_attn_varlen_func is a little faster than the Triton kernel, but it doesn't make things faster end-to-end due to the extra overhead of preparing the inputs (flattening the KV cache, building cu_seqlens_k) and the lack of CUDA graph support.
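
For context on that input-preparation overhead: flash_attn_varlen_func takes flattened variable-length tensors plus cumulative sequence-length offsets, so every step has to flatten the paged KV cache and build those offsets. A hedged sketch of the offset part, with an illustrative helper name rather than this PR's actual code:

import torch

# Illustrative sketch of the cu_seqlens_k offsets that flash_attn_varlen_func
# expects; the real input preparation also flattens the paged KV cache into
# contiguous tensors, which is where most of the overhead comes from.
def build_cu_seqlens(seq_lens, device="cuda"):
    lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
    zero = torch.zeros(1, dtype=torch.int32, device=device)
    # cu_seqlens has length batch_size + 1: [0, len0, len0 + len1, ...]
    return torch.cat([zero, torch.cumsum(lens, dim=0, dtype=torch.int32)])

# Example: three sequences with 7, 12, and 9 cached tokens.
# build_cu_seqlens([7, 12, 9], device="cpu") -> tensor([0, 7, 19, 28], dtype=torch.int32)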

@wongjingping

@ymwangg thanks for the tip - I got it to work after building it directly within my updated Dockerfile. It works great! The only thing I noticed is that memory consumption is slightly higher, and one might need to reduce the gpu_memory_utilization parameter to avoid OOM during graph building. For codellama-34b + tinyllama on 4x RTX 4090, I found that 0.80 works fine (vs. the default 0.90).

@zhuohan123 mentioned this pull request on Jan 31, 2024

@ymwangg (Contributor, Author) commented Jan 31, 2024

@wongjingping glad to hear it works now. Yes, I also observed that gpu_memory_utilization needs to be <= 0.8 for non-A100 GPUs like the A10G (23GB HBM).

@leonardxie

Hi, I ran your code, but it reports an import vllm._C error. Can you provide the file?

@ymwangg (Contributor, Author) commented Feb 2, 2024

Hi, I ran your code, but it reports an import vllm._C error. Can you provide the file?

You probably didn't install it correctly. The simplest way to install vllm is to use Docker:

docker run --gpus all -it nvidia/cuda:12.1.0-devel-ubuntu22.04 /bin/bash

and run the following inside the container:

apt update && apt install -y python3 pip git && VLLM_INSTALL_PUNICA_KERNELS=0 pip install git+https://github.com/ymwangg/vllm.git@specdec_v0.1.2

@leonardxie

Thanks, it works!

@xunfeng1980 commented Feb 4, 2024

Run

python -m vllm.entrypoints.api_server --model /mnt/user/public/llm/qwen/Qwen-14B-Chat-Int4 --tensor-parallel-size 1 --draft-model /mnt/user/public/llm/qwen/Qwen-1_8B-Chat-Int4 --speculate-length 5 --trust-remote-code --quantization gptq

Benchmark

python3 benchmark_serving.py --backend vllm --tokenizer  /mnt/user/public/llm/qwen/Qwen-14B-Chat-Int4 --dataset  /mnt/user/public/llm/qwen/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 200 --request-rate 20 --host 127.0.0.1 --port 8000  --trust-remote-code

Error

Preemption is not supported when using speculative decoding

@ymwangg (Contributor, Author) commented Feb 6, 2024

@xunfeng1980 preemption is currently not supported. You can probably reduce the request rate. It works on my machine:

(base) ubuntu@:~/src/vllm-aws/benchmarks$ python3 benchmark_serving.py --backend vllm --tokenizer  Qwen/Qwen-14B-Chat-Int4 --dataset  ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 200 --request-rate 20 --host 127.0.0.1 --port 8000  --trust-remote-code
Namespace(backend='vllm', protocol='http', host='127.0.0.1', port=8000, endpoint='/generate', model=None, dataset='ShareGPT_V3_unfiltered_cleaned_split.json', tokenizer='Qwen/Qwen-14B-Chat-Int4', best_of=1, use_beam_search=False, num_prompts=200, request_rate=20.0, seed=0, trust_remote_code=True)
WARNING 02-06 05:49:46 tokenizer.py:64] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
Token indices sequence length is longer than the specified maximum sequence length for this model (20448 > 8192). Running this sequence through the model will result in indexing errors
100%|██████████| 200/200 [01:05<00:00,  3.06it/s]
Total time: 65.34 s
Throughput: 3.06 requests/s
Average latency: 25.05 s
Average latency per token: 0.08 s
Average latency per output token: 0.32 s

Though this is slower than without speculative decoding, which yields 3.75 requests/s.

@ymwangg (Contributor, Author) commented Feb 6, 2024

@xunfeng1980 Thanks for reporting this issue. It should be fixed now. Preemption by recompute was erroneously disabled.

@xunfeng1980

@xunfeng1980 Thanks for reporting this issue. It should be fixed now. Preemption by recompute was erroneously disabled.

RTX 4090 vllm with Speculative Decoding

python -m vllm.entrypoints.api_server --model Qwen-14B-Chat-Int4 --tensor-parallel-size 1 --draft-model Qwen-1_8B-Chat-Int4 --speculate-length 5 --trust-remote-code --quantization gptq --gpu-memory-utilization 0.8

Total time: 150.98 s
Throughput: 1.32 requests/s
Average latency: 69.91 s
Average latency per token: 0.24 s
Average latency per output token: 2.14 s

RTX 4090 vllm no Speculative Decoding

python -m vllm.entrypoints.api_server --model Qwen-14B-Chat-Int4  --trust-remote-code --quantization gptq --gpu-memory-utilization 0.8

Total time: 74.73 s
Throughput: 2.68 requests/s
Average latency: 33.50 s
Average latency per token: 0.11 s
Average latency per output token: 0.83 s

Speculative decoding is slower.

@mansur20478

I have a question about line 313 of vllm/core/scheduler.py (screenshot omitted).
Should it not be can_append_multiple_slots() instead?
Also, I might be wrong, but when I checked the speculate_execute_model() function, doesn't it generate the first token for the prompt using the target model?

@ymwangg (Contributor, Author) commented Feb 9, 2024

@mansur20478 Thanks for the good catch. Yes, it's supposed to use "can_append_multiple_slots"; it slipped through during rebasing.
speculate_execute_model does generate the (n+1)-th token sampled from the target model (https://github.com/ymwangg/vllm-aws/blob/1e091790986c411dc935edd548365f07ce5d8923/vllm/worker/model_runner.py#L1042-L1051). The draft model then runs on these n+1 tokens and generates the (n+2)-th token. Note that the (n+2)-th token from the draft model is discarded here.
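
To make this bookkeeping concrete, below is a schematic sketch of one speculative step for a single sequence with greedy decoding (my simplification for illustration; the PR's actual speculate_execute_model is batched and uses the sampler's rejection rule for non-greedy sampling):

# Schematic single-sequence speculative step with greedy acceptance.
# draft_next(ctx) and target_score(ctx, draft) are placeholders, not this PR's
# interface: the first returns one greedy token from the draft model, the
# second returns the target model's greedy token at each of the k draft
# positions plus one bonus position (k + 1 tokens total, via the MQA kernel).
def speculative_step(draft_next, target_score, tokens, k):
    draft, ctx = [], list(tokens)
    for _ in range(k):
        t = draft_next(ctx)            # draft model proposes the next token
        draft.append(t)
        ctx.append(t)
    target = target_score(tokens, draft)   # length k + 1
    accepted = []
    for i in range(k):
        if draft[i] == target[i]:
            accepted.append(draft[i])  # draft token confirmed by the target
        else:
            accepted.append(target[i]) # first mismatch: take the target token and stop
            return accepted
    accepted.append(target[k])         # all k accepted: keep the bonus target token
    return accepted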

@chen3933

A few questions regarding this PR:

  1. Is there a reason to create a new variable draft_model_tp_size to shard the draft model instead of just using tensor_parallel_size?
  2. Validation is probably missing when creating configs. For example, we might want to verify that both models use the same tokenizer, etc.
  3. During profile_num_available_blocks, it only runs the main model. This might cause OOM on rank 0 when running inference.
  4. I think there is another project also targeting speculative decoding (https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit). Will this impact merging this PR?

@ymwangg (Contributor, Author) commented Feb 13, 2024

A few questions regarding this PR:

  1. Is there a reason to create a new variable draft_model_tp_size to shard the draft model instead of just using tensor_parallel_size?
  2. Validation is probably missing when creating configs. For example, we might want to verify that both models use the same tokenizer, etc.
  3. During profile_num_available_blocks, it only runs the main model. This might cause OOM on rank 0 when running inference.
  4. I think there is another project also targeting speculative decoding (https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit). Will this impact merging this PR?

Thanks for the questions!

  1. The draft model is typically much smaller than the target model, so they typically have different TP degrees. In practice, you rarely need a draft_model_tp_size other than 1.
  2. Good point. I think we can throw a warning similar to the ctx_length mismatch one; see the sketch after this list.
  3. The assumption here is that running the draft model doesn't affect peak memory usage since it's much smaller than the target model. I did notice some GPUs need a smaller gpu_memory_utilization. Not sure if this is related; I'll double check.
  4. We discussed this offline last week with other folks working on speculative decoding and will try to consolidate our efforts.
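
As an illustration of point 2, one cheap check (a hypothetical sketch, not code from this PR) is to compare how the two tokenizers encode a probe string and warn on mismatch:

import warnings
from transformers import AutoTokenizer

# Hypothetical config-validation helper; the model names in the commented-out
# call below are just examples taken from this thread.
def check_tokenizer_compatibility(target_name, draft_name,
                                  probe="Speculative decoding sanity check."):
    target_tok = AutoTokenizer.from_pretrained(target_name)
    draft_tok = AutoTokenizer.from_pretrained(draft_name)
    if target_tok(probe)["input_ids"] != draft_tok(probe)["input_ids"]:
        warnings.warn(
            f"Draft model {draft_name} and target model {target_name} tokenize "
            "text differently; speculative decoding may accept very few tokens.")

# check_tokenizer_compatibility("lmsys/vicuna-13b-v1.5",
#                               "TinyLlama/TinyLlama-1.1B-Chat-v1.0")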

@chen3933

A few questions regarding this PR:

  1. Is there a reason to create a new variable draft_model_tp_size to shard the draft model instead of just using tensor_parallel_size?
  2. Validation is probably missing when creating configs. For example, we might want to verify that both models use the same tokenizer, etc.
  3. During profile_num_available_blocks, it only runs the main model. This might cause OOM on rank 0 when running inference.
  4. I think there is another project also targeting speculative decoding (https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit). Will this impact merging this PR?

Thanks for the questions!

  1. The draft model is typically much smaller than the target model, so they typically have different TP degrees. In practice, you rarely need a draft_model_tp_size other than 1.
  2. Good point. I think we can throw a warning similar to the ctx_length mismatch one.
  3. The assumption here is that running the draft model doesn't affect peak memory usage since it's much smaller than the target model. I did notice some GPUs need a smaller gpu_memory_utilization. Not sure if this is related; I'll double check.
  4. We discussed this offline last week with other folks working on speculative decoding and will try to consolidate our efforts.

Thanks for replying. Could you please rebase this PR onto v0.3.0 (https://github.com/vllm-project/vllm/tree/v0.3.0)?

  • Co-authored-by: Jie Wang <holawj@gmail.com>
  • Fix greedy sampling in speculative decoding
  • Add back pre-emption by recompute support
  • Add logprobs support for speculative decoding
  • Fix prompt_logprobs and add stop_str support

@leonardxie

Hi, I find that after combining the speculative decoding method with vLLM, the engine is initialized with both the target model and the draft model, and then the default generate method of the LLMEngine always runs in speculative mode. Would it be possible to decide whether to use speculative mode via a parameter, in the case of concurrent calls to the LLMEngine's generate method?

@mansur20478

I have a question: how does speculative decoding work in the case of a larger batch size?
Does each sequence in the batch accept or reject tokens independently of the others?

For example, suppose my batch size is 4 and the speculation length is 7.
If the sequences are independent of each other, after a speculative decoding step:
Sequence 1 rejects the 4th token
Sequence 2 rejects the 3rd token
Sequence 3 rejects the 6th token
Sequence 4 rejects the 2nd token

If they are dependent on each other, after a speculative decoding step:
Sequence 1 rejects the 2nd token
Sequence 2 rejects the 2nd token
Sequence 3 rejects the 2nd token
Sequence 4 rejects the 2nd token
because the 4th sequence is rejected at the 2nd token, so the other sequences cannot proceed further either.

@ymwangg (Contributor, Author) commented Mar 11, 2024

I have a question: how does speculative decoding work in the case of a larger batch size? Does each sequence in the batch accept or reject tokens independently of the others?

For example, suppose my batch size is 4 and the speculation length is 7. If the sequences are independent of each other, after a speculative decoding step: Sequence 1 rejects the 4th token, Sequence 2 rejects the 3rd token, Sequence 3 rejects the 6th token, Sequence 4 rejects the 2nd token.

If they are dependent on each other, after a speculative decoding step: Sequence 1 rejects the 2nd token, Sequence 2 rejects the 2nd token, Sequence 3 rejects the 2nd token, Sequence 4 rejects the 2nd token, because the 4th sequence is rejected at the 2nd token and the other sequences cannot proceed further either.

Each sequence accepts or rejects tokens independently; the code is here: https://github.com/ymwangg/vllm/blob/specdec_v0.1.2/vllm/model_executor/layers/sampler.py#L727-L740.
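
For readers who don't want to dig through the sampler, the per-sequence rule is essentially standard speculative-sampling rejection; a sketch for one sequence (the tensor layout and names are mine, not the PR's) looks like:

import torch

# Per-sequence accept/reject rule, sketched for a single sequence with k draft
# tokens; the PR's sampler implements an equivalent batched version at the
# link above. p_dist[i] / q_dist[i] are the target / draft probability
# distributions (shape [vocab]) at draft position i; draft_tokens[i] is the
# token the draft model proposed there.
def accept_draft_tokens(draft_tokens, p_dist, q_dist):
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_dist[i][tok], q_dist[i][tok]
        # Accept token i with probability min(1, p/q); each sequence in the
        # batch does this independently of the others.
        if torch.rand(()) <= torch.clamp(p / q, max=1.0):
            accepted.append(int(tok))
            continue
        # First rejection: resample from the normalized residual (p - q)+ and
        # stop; the remaining draft tokens for this sequence are discarded.
        residual = torch.clamp(p_dist[i] - q_dist[i], min=0.0)
        accepted.append(int(torch.multinomial(residual / residual.sum(), 1)))
        break
    return accepted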

@dutsc commented Mar 26, 2024

Hello! Thank you very much for your work! I was very interested in it, so I fetched #2607 locally for research, but I encountered a problem similar to #1391 when running pip install -e . Have you encountered similar problems? My CUDA version is 12.1, the torch version is 2.2.0, and the GPU is an RTX 3090.

@Heelim-Hong

[Question on Increasing Single Decoding Time with Speculative Decoding as Batch Size Increases]

I am exploring the impact of speculative decoding on inference efficiency in vLLM and have observed some intriguing behavior regarding decoding times.

From my experiments, I noticed that when speculative decoding is not used, the single-step decoding time remains relatively stable across different batch sizes. However, when speculative decoding is enabled, there is a significant increase in single-step decoding time as the batch size increases.

For context, the speculative decoding setup I am using involves a draft model with the speculation length set to 4. I have attached a table below that illustrates these observations:

(table of per-step decoding times omitted)

I understand that the overall computational load increases with speculative decoding due to the use of a draft model. However, I am curious why the increase in single-step decoding time is so pronounced at larger batch sizes. Could this be related to the overhead of running the verification passes on the target model?

I used LLaMA-13B as the target model and LLaMA-68M as the draft model, on four A100 (80GB) GPUs with TP degree 4.

Any insights or explanations would be greatly appreciated.

Thank you for your time and assistance.

@ymwangg (Contributor, Author) commented Apr 10, 2024

@dutsc it looks like this issue is specific to Windows. Sorry, I don't have access to a Windows setup. Maybe other folks in the community can help you.

@ymwangg (Contributor, Author) commented Apr 10, 2024

Hi @Heelim-Hong, it's expected that the speedup from speculative decoding keeps decreasing as you increase the batch size. Yes, this is related to how verification works. At a low level, you can think of it as multiplying two matrices of shape [b, m, k] * [k, n], where k and n are large. Normal decoding uses [b, 1, k] * [k, n] to decode b tokens, while speculative decoding uses [b, m, k] * [k, n] to decode between 1 and m tokens per sequence. When b * m is small, it benefits from the memory-bound regime, where the cost of [b, m, k] * [k, n] is roughly the same as that of [b, 1, k] * [k, n]. For larger b * m, this no longer holds, and therefore you see lower speedups with speculative decoding.
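
A back-of-the-envelope way to see this (my own illustration, not numbers from this thread): the time for one weight matrix is the maximum of its compute time and its weight-read time, and the weight-read time is independent of b and m, so verification is nearly free until b * m pushes the matmul into the compute-bound regime.

# Rough roofline-style estimate for one weight matrix of shape [k, n],
# comparing normal decoding ([b, 1, k] @ [k, n]) with verification
# ([b, m, k] @ [k, n]). The hardware numbers are illustrative (roughly
# A100-like fp16 peak and HBM bandwidth), not measurements.
def step_time_estimate(b, m, k, n, peak_flops=312e12, mem_bw=2.0e12, bytes_per_weight=2):
    flops = 2 * b * m * k * n                     # multiply-accumulates
    weight_bytes = k * n * bytes_per_weight       # weights are read once per step
    return max(flops / peak_flops, weight_bytes / mem_bw)

k, n = 8192, 8192
for b in (1, 8, 64, 256):
    normal = step_time_estimate(b, 1, k, n)       # decode 1 token per sequence
    spec = step_time_estimate(b, 5, k, n)         # verify 5 draft tokens per sequence
    print(f"b={b:4d}  verification/normal time ratio ~ {spec / normal:.2f}")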

@Heelim-Hong

Hi @Heelim-Hong, it's expected that the speedup from speculative decoding keeps decreasing as you increase the batch size. Yes, this is related to how verification works. At a low level, you can think of it as multiplying two matrices of shape [b, m, k] * [k, n], where k and n are large. Normal decoding uses [b, 1, k] * [k, n] to decode b tokens, while speculative decoding uses [b, m, k] * [k, n] to decode between 1 and m tokens per sequence. When b * m is small, it benefits from the memory-bound regime, where the cost of [b, m, k] * [k, n] is roughly the same as that of [b, 1, k] * [k, n]. For larger b * m, this no longer holds, and therefore you see lower speedups with speculative decoding.

Hi @ymwangg. Thank you very much for your response. I see that b represents the batch size and m represents the speculation length, but what do k and n represent?

@simon-mo closed this on Oct 2, 2024