
Introduce speculative decoding with draft models to vLLM #3029

Closed
wants to merge 3 commits

Conversation

sighingnow
Contributor

@sighingnow sighingnow commented Feb 25, 2024

This PR is (yet another) implementation of speculative decoding in vLLM. Compared with existing efforts (including #2607, #1797, and #1679), this PR:

  • is complete and confirmed to work correctly
  • is probably the most up-to-date implementation (based on the current main)
  • is the simplest
  • supports a (paged) KV cache for the draft model as well
  • is compatible with existing features, including the paged KV cache and prefix caching

(Note that this PR is built upon PR #3007 (GQA fixes for context_attention_fwd) and PR #3010 (introducing flash-attn to vLLM), but does not depend on them. If those two PRs are accepted, I will rebase this PR onto them; otherwise, submitting the speculative sampling feature in a separate PR also works for me.)

The major changes are confined to llm_engine.py's step() method and model_runner.py's _prepare_decode()/_prepare_sample() methods, so the diff should be fairly easy to review.

The major design & implementation can be highlighted as follows:

  • KV cache:
    • For the target and draft models, the usable memory for the KV cache is split in proportion to the per-token KV cache size of each model.
      • i.e., num_gpu_blocks is computed as gpu_memory / (draft_block_size + target_block_size); see the sketch after this list.
    • The draft model and the target model share the same block table and slot mapping, simplifying the scheduler implementation.
  • The scheduler is responsible for choosing sequences to run (as usual), and the LLMEngine's step() method runs the draft model k times followed by one target model step (a sketch of this loop follows the TODO list below).
  • Attention:
    • Speculative decoding requires verifying k tokens at once; this is implemented using the context_attention_fwd kernel (originally added for prefix caching).
    • Compatible with flash-attn's flash_attn_with_kvcache kernel.
  • Sampling:
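A minimal sketch of the block-count computation above, with hypothetical names (free_gpu_memory, bytes_per_token, block_size) rather than the PR's actual helpers:

```python
def split_kv_cache_blocks(free_gpu_memory: int,
                          target_bytes_per_token: int,
                          draft_bytes_per_token: int,
                          block_size: int = 16) -> int:
    """Number of KV cache blocks that the target and draft models can both hold."""
    # Bytes needed to store one block (block_size tokens) of KV cache per model.
    target_block_bytes = target_bytes_per_token * block_size
    draft_block_bytes = draft_bytes_per_token * block_size

    # The two models share block ids, so every allocated block id must be backed
    # by a target block and a draft block:
    #   num_gpu_blocks = gpu_memory / (draft_block_size + target_block_size)
    return free_gpu_memory // (target_block_bytes + draft_block_bytes)
```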

TODO:

  • The interaction between speculative decoding and beam search is a bit complicated (see also appendix 4 of the original ICML paper). This PR raises NotImplementedError for that case and leaves it as a TODO.
  • Support initializing draft workers on Ray.
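For orientation only, here is a minimal, hypothetical sketch of the draft-then-verify loop that step() performs, assuming batch size 1 and illustrative names (speculative_step, and models that return logits of shape [batch, seq, vocab]); the PR's real implementation lives in llm_engine.py and model_runner.py:

```python
import torch

@torch.no_grad()
def speculative_step(target_model, draft_model, input_ids: torch.Tensor, k: int):
    """One speculative step: propose k tokens with the draft model,
    then verify them all with a single target-model forward pass."""
    draft_tokens, draft_probs = [], []

    # 1. Run the draft model autoregressively for k lookahead steps.
    ids = input_ids
    for _ in range(k):
        logits = draft_model(ids)[:, -1, :]              # [1, vocab]
        probs = torch.softmax(logits, dim=-1)
        token = torch.multinomial(probs, num_samples=1)  # [1, 1]
        draft_tokens.append(token)
        draft_probs.append(probs.gather(-1, token))      # q(token)
        ids = torch.cat([ids, token], dim=-1)

    # 2. A single target step scores all k proposed positions at once
    #    (in the PR this is where context_attention_fwd / flash-attn is used).
    target_logits = target_model(ids)[:, -k - 1:-1, :]   # [1, k, vocab]
    target_probs = torch.softmax(target_logits, dim=-1)

    # 3. Accept draft token i with probability min(1, p(token) / q(token));
    #    stop at the first rejection (that position would then be resampled
    #    from the adjusted target distribution, omitted here).
    accepted = []
    for i, token in enumerate(draft_tokens):
        p = target_probs[:, i, :].gather(-1, token)
        if torch.rand(1) < torch.clamp(p / draft_probs[i], max=1.0):
            accepted.append(int(token))
        else:
            break
    return accepted
```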

Numbers for a prompt randomly chosen from the dataset, using Llama-2-70B-GPTQ as the target model and Llama-2-7B-GPTQ / TinyLlama-1.1B-Chat-v1.0-GPTQ as the draft model:

| Model | Draft Model | Quantization | Prompt Len | Gen Len | Batch Size | Lookahead | Drafted | Accepted | Accept Rate (%) | TG Speed (tokens/s) | TG Latency (s) |
|-------|-------------|--------------|------------|---------|------------|-----------|---------|----------|-----------------|---------------------|----------------|
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 1 | 0 | 0 | 0 | 33.8 | 0.0295858 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 1 | 0 | 0 | 0 | 33.9 | 0.02949853 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 2 | 109 | 85 | 77 | 43.3 | 0.02309469 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 2 | 107 | 93 | 86 | 42.8 | 0.02336449 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 3 | 162 | 113 | 69 | 53.1 | 0.01883239 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 3 | 149 | 121 | 81 | 51.4 | 0.01945525 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 4 | 197 | 116 | 58 | 53.6 | 0.01865672 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 4 | 203 | 150 | 73 | 47.8 | 0.0209205 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 5 | 234 | 136 | 58 | 52.3 | 0.01912046 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 5 | 205 | 166 | 80 | 52.6 | 0.01901141 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 6 | 262 | 153 | 58 | 47 | 0.0212766 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 6 | 260 | 180 | 69 | 42 | 0.02380952 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 7 | 295 | 154 | 52 | 46.2 | 0.02164502 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 7 | 269 | 173 | 64 | 45.7 | 0.02188184 |
| 70B | 1.1B | GPTQ | 10 | 200 | 1 | 8 | 707 | 268 | 38 | 39.7 | 0.02518892 |
| 70B | 7B | GPTQ | 10 | 200 | 1 | 8 | 274 | 184 | 67 | 46.9 | 0.02132196 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 1 | 0 | 0 | 0 | 283 | 0.22614841 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 1 | 0 | 0 | 0 | 283 | 0.22614841 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 2 | 31872 | 24621 | 78 | 427.7 | 0.1496376 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 2 | 29632 | 26226 | 90 | 422.9 | 0.15133601 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 3 | 41088 | 29947 | 76 | 578.1 | 0.11070749 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 3 | 36608 | 29547 | 83 | 490.5 | 0.1304791 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 4 | 41344 | 25872 | 63 | 610.7 | 0.10479777 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 4 | 36544 | 26938 | 76 | 562.2 | 0.11383849 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 5 | 45760 | 27635 | 62 | 640.7 | 0.09989074 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 5 | 40704 | 30531 | 76 | 548.3 | 0.11672442 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 6 | 55872 | 28106 | 51 | 661.3 | 0.09677907 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 6 | 47040 | 31387 | 69 | 597 | 0.10720268 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 7 | 59520 | 26615 | 50 | 654.3 | 0.09781446 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 7 | 66560 | 40516 | 64 | 612 | 0.10457516 |
| 70B | 1.1B | GPTQ | 10 | 200 | 64 | 8 | 65408 | 26316 | 44 | 710.4 | 0.09009009 |
| 70B | 7B | GPTQ | 10 | 200 | 64 | 8 | 71872 | 41465 | 60 | 593.1 | 0.1079076 |

@sighingnow sighingnow force-pushed the ht/speculative branch 10 times, most recently from 2585153 to 5529f29 on February 25, 2024 at 12:01
@nivibilla

Hi, this doesn't seem to work with ray atm. I was trying to benchmark Mixtral8x7b with Mistral 7b as the draft model.

AssertionError: Speculative decoding is not supported with Ray.

@sighingnow
Contributor Author

sighingnow commented Feb 25, 2024

Hi, this doesn't seem to work with ray atm. I was trying to benchmark Mixtral8x7b with Mistral 7b as the draft model.

AssertionError: Speculative decoding is not supported with Ray.

Haven't tested in a Ray environment yet (the worker initialization steps are missing); fixing that shouldn't be complicated, though.

Ray worker initialization added.

@nivibilla

Hi, this doesn't seem to work with ray atm. I was trying to benchmark Mixtral8x7b with Mistral 7b as the draft model.
AssertionError: Speculative decoding is not supported with Ray.

Haven't tested in a Ray environment yet (the worker initialization steps are missing); fixing that shouldn't be complicated, though.

Ah I see, np. I was trying to test with tp 8 on 8xA10 setup.

@nivibilla

Also, I'm getting this error when trying to run

ModuleNotFoundError: No module named 'vllm.worker.spec_decode'

I'm installing it like

!pip install git+https://github.com/sighingnow/vllm.git@ht/speculative

@sighingnow
Contributor Author

ModuleNotFoundError: No module named 'vllm.worker.spec_decode'

Should be fixed now. Ray worker initialization has been added as well (not tested on ray, but it basically shares the same logic as before).

@nivibilla

nivibilla commented Feb 25, 2024

Just tried this again, with tp 8. It freezes after

Namespace(model='/local_disk0/LoneStriker/Mixtral-8x7B-Instruct-v0.1-HF', draft_model='/local_disk0/mistralai/Mistral-7B-Instruct-v0.2', tokenizer='/local_disk0/LoneStriker/Mixtral-8x7B-Instruct-v0.1-HF', quantization=None, tensor_parallel_size=8, input_len=32, output_len=256, batch_size=1, n=1, use_beam_search=False, temperature=0.0, num_iters=3, trust_remote_code=False, dtype='auto', enforce_eager=True, kv_cache_dtype='auto', profile=False, profile_result_dir=None, device='cuda', use_flash_attn=True, parallel_decoding_lookahead=5)
INFO 02-25 14:00:57 config.py:433] Custom all-reduce kernels are temporarily disabled due to stability issues. We will re-enable them once the issues are resolved.
2024-02-25 14:00:59,910	INFO worker.py:1724 -- Started a local Ray instance.
INFO 02-25 14:01:11 llm_engine.py:83] Initializing an LLM engine with config: model='/local_disk0/LoneStriker/Mixtral-8x7B-Instruct-v0.1-HF', draft_model='/local_disk0/mistralai/Mistral-7B-Instruct-v0.2', tokenizer='/local_disk0/LoneStriker/Mixtral-8x7B-Instruct-v0.1-HF', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=8, disable_custom_all_reduce=True, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, seed=0)

Something does seem to be loaded into memory though

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A10G                    Off | 00000000:00:16.0 Off |                    0 |
|  0%   24C    P0              56W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A10G                    Off | 00000000:00:17.0 Off |                    0 |
|  0%   25C    P0              57W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   2  NVIDIA A10G                    Off | 00000000:00:18.0 Off |                    0 |
|  0%   25C    P0              57W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   3  NVIDIA A10G                    Off | 00000000:00:19.0 Off |                    0 |
|  0%   24C    P0              56W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   4  NVIDIA A10G                    Off | 00000000:00:1A.0 Off |                    0 |
|  0%   26C    P0              57W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   5  NVIDIA A10G                    Off | 00000000:00:1B.0 Off |                    0 |
|  0%   25C    P0              56W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   6  NVIDIA A10G                    Off | 00000000:00:1C.0 Off |                    0 |
|  0%   26C    P0              59W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   7  NVIDIA A10G                    Off | 00000000:00:1D.0 Off |                    0 |
|  0%   26C    P0              56W / 300W |  11558MiB / 23028MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

I'm running the benchmark file

!python3 /local_disk0/vllm/benchmarks/benchmark_latency.py \
      --model /local_disk0/LoneStriker/Mixtral-8x7B-Instruct-v0.1-HF \
      --draft-model /local_disk0/mistralai/Mistral-7B-Instruct-v0.2 \
      --tokenizer /local_disk0/LoneStriker/Mixtral-8x7B-Instruct-v0.1-HF \
      --tensor-parallel-size 8 \
      --temperature 0 \
      --num-iters 3 \
      --input-len 32 \
      --output-len 256 \
      --batch-size 1 \
      --use-flash-attn \
      --parallel-decoding-lookahead 5 \
      --enforce-eager

@sighingnow
Contributor Author

sighingnow commented Feb 25, 2024

Just tried this again, with tp 8. It freezes after

Thanks for the feedback. Could you please help to confirm that

  • without speculative decoding (and the draft model), does the model Mixtral-8x7B-Instruct-v0.1-HF work as expected?
  • without Ray, do smaller models work on a single GPU?

Thanks!

@nivibilla

Yes sure.

Main Branch

Mixtral 8x7b
TP 8
Avg latency: 8.645047865666735 seconds
Gemma 2b
TP 1
Avg latency: 3.522750368333618 seconds

On PR

Mixtral 8x7b + Flash Attn
TP 8
Avg latency: 9.122520502666399 seconds
Mixtral 8x7b w/o Flash Attn
TP 8
Avg latency: 8.954343442333387 seconds
Gemma 2b + Flash Attn
TP 1
Avg latency: 3.2953348303332555 seconds
Gemma 2b w/o Flash Attn
TP 1
Error

Traceback (most recent call last):
  File "/local_disk0/vllm/benchmarks/benchmark_latency.py", line 163, in <module>
    main(args)
  File "/local_disk0/vllm/benchmarks/benchmark_latency.py", line 71, in main
    run_to_completion(profile_dir=None)
  File "/local_disk0/vllm/benchmarks/benchmark_latency.py", line 63, in run_to_completion
    llm.generate(prompt_token_ids=dummy_prompt_token_ids,
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 187, in generate
    return self._run_engine(use_tqdm)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 213, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 983, in step
    all_outputs: List[SamplerOutput] = self._run_workers(
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1385, in _run_workers
    driver_worker_output = getattr(driver_worker,
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/worker/worker.py", line 223, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 625, in execute_model
    hidden_states = model_executable(
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/model_executor/models/gemma.py", line 269, in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/model_executor/models/gemma.py", line 237, in forward
    hidden_states, residual = layer(
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/model_executor/models/gemma.py", line 189, in forward
    hidden_states = self.self_attn(
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/model_executor/models/gemma.py", line 141, in forward
    attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/model_executor/layers/attention.py", line 249, in forward
    context_attention_fwd(
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/local_disk0/.ephemeral_nfs/envs/pythonEnv-b98302a4-06cd-438a-b491-566113098b11/lib/python3.10/site-packages/vllm/model_executor/layers/triton_kernel/prefix_prefill.py", line 639, in context_attention_fwd
    assert Lk in {16, 32, 64, 128}
AssertionError

@nivibilla

nivibilla commented Feb 25, 2024

Wait it just worked this time (EDIT, this is lookahead 1 so not using spec decode)

Mixtral 8x7b + Draft Mistral 7b
TP 8
Avg latency: 9.251173579999886 seconds

Will test some more combinations

@sighingnow
Contributor Author

Wait it just worked this time

Mixtral 8x7b + Draft Mistral 7b
TP 8
Avg latency: 9.251173579999886 seconds

Will test some more combinations

Thanks for reporting! I have just added some of my previous experiment results to the PR description.

@nivibilla

When I use

--parallel-decoding-lookahead 5 \

It freezes. Using lookahead 1 works fine

@sighingnow
Contributor Author

sighingnow commented Feb 25, 2024

When I use

--parallel-decoding-lookahead 5 \

It freezes. Using lookahead 1 works fine

Using lookahead 1 disables speculative decoding, so the draft model won't even be loaded.

Have you tested whether it works with Mistral 7b as the only main model, without speculative decoding?

@nivibilla

In your experience, how long does the benchmark take?
Could it be that I'm not waiting long enough? It stops at initialising the engine for 10 minutes.

@sighingnow
Contributor Author

sighingnow commented Feb 25, 2024

In your experience, how long does the benchmark take? Could it be that I'm not waiting long enough? It stops at initialising the engine for 10 minutes.

The draft model and the target model share the same code path, so there shouldn't be much difference in model loading (apart from the model size).

@nivibilla

I think there's something wrong with my setup. It stops after loading one model.

@sighingnow
Contributor Author

I think there's something wrong with my setup. It stops after loading one model.

I cannot find an 8-GPU machine to run the un-quantized Mixtral 8x7b and Mistral 7b models, but I can confirm that using the GPTQ versions, --model ~/models/Mixtral-8x7B-Instruct-v0.1-GPTQ --draft-model ~/models/Mistral-7B-Instruct-v0.2-GPTQ works on a single 80Gi GPU.

Signed-off-by: Tao He <sighingnow@gmail.com>
Signed-off-by: Tao He <sighingnow@gmail.com>
Signed-off-by: Tao He <sighingnow@gmail.com>

if input_metadata.use_flash_attn:
    # see also: https://github.com/Dao-AILab/flash-attention/commit/54e80a3829c6d2337570d01e78ebd9529c02d342
    output = flash_attn_with_kvcache(
@ymwangg
Contributor

ymwangg commented Feb 28, 2024

Thanks for sharing the implementation. Is flash_attn_with_kvcache faster than context_attention_fwd?

@sighingnow
Contributor Author

I haven't benchmarked these two kernels, but I have added an end-to-end speculative sampling benchmark in the last part of #3010's pull request description.

@zhaoyang-star
Contributor

I met the same issue when tp=4. It freezes after

Namespace(model='/robot-shared/models/huggingface/CodeLlama-70b-hf--4570a4e-C6', draft_model='/robot-shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23', tokenizer=None, quantization=None, tensor_parallel_size=4, input_len=512, output_len=512, batch_size=1, n=1, use_beam_search=False, temperature=1.0, num_iters=10, trust_remote_code=False, dtype='auto', enforce_eager=True, kv_cache_dtype='auto', profile=False, profile_result_dir=None, device='cuda', use_flash_attn=False, parallel_decoding_lookahead=4)
INFO 02-28 17:15:36 config.py:433] Custom all-reduce kernels are temporarily disabled due to stability issues. We will re-enable them once the issues are resolved.
2024-02-28 17:15:38,084	WARNING utils.py:575 -- Detecting docker specified CPUs. In previous versions of Ray, CPU detection in containers was incorrect. Please ensure that Ray has enough CPUs allocated. As a temporary workaround to revert to the prior behavior, set `RAY_USE_MULTIPROCESSING_CPU_COUNT=1` as an env var before starting Ray. Set the env var: `RAY_DISABLE_DOCKER_CPU_WARNING=1` to mute this warning.
2024-02-28 17:15:39,230	INFO worker.py:1724 -- Started a local Ray instance.
INFO 02-28 17:15:40 llm_engine.py:83] Initializing an LLM engine with config: model='/robot-shared/models/huggingface/CodeLlama-70b-hf--4570a4e-C6', draft_model='/robot-shared/models/huggingface/CodeLlama-7b-hf--3773f63-C23', tokenizer='/robot-shared/models/huggingface/CodeLlama-70b-hf--4570a4e-C6', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=auto, tensor_parallel_size=4, disable_custom_all_reduce=True, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, seed=0)

The command is as follows:

                python3 benchmarks/benchmark_latency.py \
                    --model ${model} \
                    --draft-model ${draft_model} \
                    --temperature 1.0 \
                    --parallel-decoding-lookahead 4 \                                                                                                                                                                                         
                    --enforce-eager \
                    --tensor-parallel-size ${tp} \
                    --input-len ${lens} \
                    --output-len ${lens} \
                    --batch-size ${bs} \
                    --num-iters ${iters} 2>&1|tee -a ${log_file}

The command works fine with --parallel-decoding-lookahead 1, so I guess there may still be a bug when tp>1.

@zhaoyang-star
Contributor

The data is interesting. I think sampling params such as temperature and frequency_penalty will affect the accept rate. Could you share your full config and sampling params?

@zhaoyang-star
Contributor

Another concern: speculative decoding yields a lower TPOT but a higher TTFT. Is it possible to skip speculation for the first generated token? To make it clearer: the first generated token is produced by the target model only, and speculative decoding is only applied to the second and following tokens.

@sighingnow
Contributor Author

So I guess there may still be a bug when tp>1.

I will give it a try.

Could you share your full config and sampling params?

I just set the temperature for better reproducibility in benchmarks. In my experience, temperature=0.0 and temperature=1.0 differ a bit in the final accept rate, but not much.

To make it clearer: the first generated token is produced by the target model only, and speculative decoding is only applied to the second and following tokens.

I don't think it can be achieved, as the autoregressive steps of the draft model require the prompt tokens' KV cache.

There exist many other speculative decoding designs, e.g., self-speculative decoding, which doesn't require a draft model. Such strategies are not included in this PR but are compatible with the current design.

@sighingnow
Contributor Author

The data is interesting. I think sampling params such as temperature and frequency_penalty will affect the accept rate. Could you share your full config and sampling params?

The accept rate varies between datasets as well.

@zhaoyang-star
Contributor

@sighingnow Thanks for your reply. The main blocker for me is the TTFT, so I have to find some other way to address it.

@sighingnow
Contributor Author

@sighingnow Thanks for your reply. The main blocker for me is the TTFT, so I have to find some other way to address it.

If I understand it correctly,

  • speculative decoding can maintain the same TTFT by prefilling the target model first and returning the first token as soon as it is generated, before the draft model's prefill, with some kind of asynchrony.
  • I think the prefix cache may help with TTFT.

@zhaoyang-star
Contributor

@sighingnow Hi Tao, have you solved the tp>1 error? Speculative decoding is mainly used for models with a large number of parameters, for which tensor parallelism is needed. Looking forward to good news :)

@zhaoyang-star
Contributor

I think the prefix cache may help with TTFT.

Automatic prefix caching (#2762) has been merged, but there is no performance increase yet. The performance issue is on the TODO list.

@sighingnow
Contributor Author

@sighingnow Hi Tao, have you solved the tp>1 error? Speculative decoding is mainly used for models with a large number of parameters, for which tensor parallelism is needed. Looking forward to good news :)

Haven't tried yet. I will try to reproduce it.

At the same time, I have noticed that @cadedaniel has submitted the [3/9] PR (#3103) for speculative decoding. I would like to know what the vLLM community thinks about the development plan for speculative decoding. If it is confirmed that this PR won't be accepted/merged, even partially, then refining it further may not help much.

@cadedaniel
Collaborator

At the same time, I have noticed that @cadedaniel has submitted the [3/9] PR (#3103) for speculative decoding. I would like to know what the vLLM community thinks about the development plan for speculative decoding. If it is confirmed that this PR won't be accepted/merged, even partially, then refining it further may not help much.

Thanks for the work on this -- it's a good PR :). To answer your question, @simon-mo, @LiuXiaoxuanPKU, myself, and the authors of #2607 met a few weeks ago, and @simon-mo wants to go with #2188 first. The key idea is that it refactors some vLLM internals so that different types of speculative decoding are supported, e.g. prompt-lookup, RAG acceleration, topk+top1 Eagle/Medusa, typical acceptance, etc.

I'll be working full-time this week to finish the PRs. After the correctness tests are in, happy to collaborate/accept optimizations.

@sighingnow
Contributor Author

I'll be working full-time this week to finish the PRs. After the correctness tests are in, happy to collaborate/accept optimizations.

Thank you, @cadedaniel! ❤️ Looking forward to the progress of #2188; this PR actually benefits from it as well. I will see if there is still something worth submitting to vLLM after this series of PRs has been merged.

@UranusSeven
Contributor

@sighingnow Thanks for your awesome work! I'm currently working on integrating https://github.com/SafeAILab/EAGLE with vLLM. EAGLE leverages tree attention to score the draft tokens, which applies an attention mask (different from the causal attention mask) to the attention scores. However, the current attention implementations, including flash-attn and flashinfer, don't seem able to handle that. Do you have any suggestions? Any guidance you can offer would be greatly appreciated.

@sighingnow
Contributor Author

However, the current attention implementations, including flash-attn and flashinfer, don't seem able to handle that. Do you have any suggestions? Any guidance you can offer would be greatly appreciated.

For tree attention, one possible approach might be forking the tree into several sequences as a batch (their prefixes share the KV cache)?

I think hacking the Triton kernel context_attention_fwd is also possible, to avoid forking sequences and copying KV cache blocks.
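Not part of this PR, but as a concrete illustration of the mask tree attention needs (each draft token may only attend to itself and its ancestors in the token tree), here is a small hypothetical sketch:

```python
import torch

def build_tree_attention_mask(parents: list[int]) -> torch.Tensor:
    """Build a boolean attention mask for a token tree.

    parents[i] is the index of token i's parent, or -1 for a root.
    mask[i, j] is True iff token i may attend to token j, i.e. j is i itself
    or one of i's ancestors.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:          # walk up the tree from token i to the root
            mask[i, j] = True
            j = parents[j]
    return mask

# Example: a root token 0 with two children (1, 2), each with one child (3, 4).
print(build_tree_attention_mask([-1, 0, 0, 1, 2]))
```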

@sighingnow
Contributor Author

I think the prefix cache may help with TTFT.

Automatic prefix caching (#2762) has been merged, but there is no performance increase yet. The performance issue is on the TODO list.

This might be caused by cache misses and the inefficiency of the Triton kernel. Flash-attn may help.

@UranusSeven
Contributor

However, the current attention implementations, including flash-attn and flashinfer, don't seem able to handle that. Do you have any suggestions? Any guidance you can offer would be greatly appreciated.

For tree attention, one possible approach might be forking the tree into several sequences as a batch (their prefixes share the KV cache)?

I think hacking the Triton kernel context_attention_fwd is also possible, to avoid forking sequences and copying KV cache blocks.

True, and that's what I'm doing with flashinfer. Thanks for your advice :)

"--parallel-decoding-lookahead",
type=int,
default=1,
help="Number of lookahead steps for speculativespeculative decoding.")

Typo: duplicated word ("speculativespeculative").

@sighingnow sighingnow closed this Jun 3, 2024