Skip to content

[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE #17211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 29, 2025

Conversation

luyuzhe111
Copy link
Contributor

@luyuzhe111 luyuzhe111 commented Apr 26, 2025

Task 8 of #15901

A few notes regarding the implementation.

Torch.compile

  1. @support_torch_compile is convenient but requires a specific signature for the model. Changes in vllm/model_executor/models/llama_eagle.py address this requirement.
  2. Further, to make torch.compile work, we need a separate cache dir for EAGLE model. Without edits in vllm/compilation/backends.py, we wouldn't be able to cache EAGLE's compilation properly since the EAGLE module was registered under vllm config from the target model (see here)
  3. One notable bug related to torch.compile is this line. Essentially, the default data type for input ids is int32 and the EAGLE model was also compiled with this data type. However, tensor.argmax() returns int64 by default. Feeding int64 input ids to the compiled model will completely mess things up and lead to gibberish draft tokens. Currently the compiled model does not even give warnings when the input data type mismatches. Wonder if we can prevent similar bugs in the future by some more checks.

CudaGraph

Changes in vllm/v1/spec_decode/eagle.py and vllm/v1/worker/gpu_model_runner.py are mostly for CudaGraph. Nothing fancy other than registering additional persistent buffers and making sure to use them for EAGLE's forward pass. I do want to mention that with torch.compile & CudaGraph, the EAGLE model's forward pass has been drastically improved (2.5x faster), which makes the small but abundant torch operations look inefficient. Any advice to further optimize these overheads is greatly appreciated.

Finally, it would be great to have #17010 reviewed and merged so that we don't have to pull in other PRs to test the acceptance length.

Note that the current PR does not directly make torch.compile & cuda graph available for EAGLE3. I think it's worth a separate PR since the work is non-trivial due to the fact that EAGLE-3's input hidden states have dynamic shapes. Maybe @benchislett could help.

@WoosukKwon @LiuXiaoxuanPKU

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Apr 26, 2025
Copy link

mergify bot commented Apr 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luyuzhe111.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 26, 2025
@luyuzhe111 luyuzhe111 changed the title Apply torch.compile & cudagraph to EAGLE [V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE Apr 26, 2025
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@mergify mergify bot added the documentation Improvements or additions to documentation label Apr 26, 2025
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@WoosukKwon
Copy link
Collaborator

Thanks for the PR. This is so cool!
I’ll take a look, but it would be great if we could also get @youkaichao’s review.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luyuzhe111 Thanks for the PR!

Left some comments. Please take a look!

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luyuzhe111 Thanks for addressing the comments! One last thing: Can you please add tests for this?

@luyuzhe111
Copy link
Contributor Author

luyuzhe111 commented Apr 29, 2025

Hi @ekagra-ranjan, thanks for this PR! but wondering if you could create a new PR later that makes sure the draft model takes in vllm_config instead of model_config, if you are interested? This is a requirement for compatibility with torch.compile decorator. for this PR I will just create a condition to handle eagle-1 and eagle-3 models separately. thanks! cc @WoosukKwon

@luyuzhe111
Copy link
Contributor Author

One last thing: Can you please add tests for this?

@WoosukKwon I feel like acceptance length tests are probably the most meaningful tests for this PR. I tested the acceptance length by cherry picking commits from this PR. You can see that the difference is less than 0.01.

With meta-llama/Llama-3.1-8B-Instruct, yuhuili/EAGLE-LLaMA3.1-Instruct-8B

On MT Bench

When max number generated tokens = 256

Number of Speculated Tokens 1 2 3 4 5
Eager 1.71 2.09 2.30 2.38 2.43
Compilation & CudaGraph 1.70 2.10 2.29 2.38 2.43

I can do a follow-up PR adding acceptance length tests after #17010 is merged.

@ekagra-ranjan
Copy link
Contributor

ekagra-ranjan commented Apr 29, 2025

Thank you @luyuzhe111 for the PR!
Wondering if you have any benchmark for speedup in output token/s for with and w/o cuda graph and torch compile for EAGLE llama 3.1 on MTBench TP1 BS1?

I do want to mention that with torch.compile & CudaGraph, the EAGLE model's forward pass has been drastically improved (2.5x faster), which makes the small but abundant torch operations look inefficient.

This is fantastic!
Could you share the script you used to measure just the EAGLE's fwd pass?
By "abundant torch operation", do you mean the torch ops outside the forward pass but within the propose() ?

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@luyuzhe111
Copy link
Contributor Author

Hey @ekagra-ranjan Re:

  1. I remember you had some numbers for benchmarking on mt bench. can you share the setup so that I can run the benchmarking again for torch.compile + cuda graph? thanks!
  2. for the forward pass I had to look at the profiler, so no script here.
  3. right I meant those operations that prepare inputs for the EAGLE model.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 29, 2025
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. @youkaichao Any final comments?

@ekagra-ranjan
Copy link
Contributor

@luyuzhe111 - here is an example setup for benchmark. Looking fwd to the results!

@WoosukKwon WoosukKwon enabled auto-merge (squash) April 29, 2025 21:08
@WoosukKwon WoosukKwon merged commit 70788bd into vllm-project:main Apr 29, 2025
69 checks passed
@luyuzhe111
Copy link
Contributor Author

@ekagra-ranjan I got the following results:

Target model: meta-llama/Llama-3.1-8B-Instruct
EAGLE model: yuhuili/EAGLE-LLaMA3.1-Instruct-8B
Hardware: A100 (40GB)
Script: VLLM_USE_V1=1 python examples/offline_inference/eagle.py —dataset="./data/mt_bench/question.jsonl" —num_spec_tokens x —max_num_seqs 1 —num_prompts 80

Regular Decoding OTPS: 72
Screenshot 2025-04-29 at 3 21 29 PM

it looks like a further 10% speedup.

@zou3519
Copy link
Collaborator

zou3519 commented May 1, 2025

One notable bug related to torch.compile is this line. Essentially, the default data type for input ids is int32 and the EAGLE model was also compiled with this data type. However, tensor.argmax() returns int64 by default. Feeding int64 input ids to the compiled model will completely mess things up and lead to gibberish draft tokens. Currently the compiled model does not even give warnings when the input data type mismatches. Wonder if we can prevent similar bugs in the future by some more checks.

These checks lead to performance degradation so vLLM decided to drop all of them. Some of the lower compilation_levels (e.g. 1) should do the checking but those also have slightly different behavior for other things

Comment on lines +415 to +421
if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
Copy link
Collaborator

@zou3519 zou3519 May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luyuzhe111 This is suspicious. Why do you need a different cache directory for each graph? Also, this looks like it modifies everything, even the models that don't use eagle.

If there isn't a good reason I would prefer going back to the "single cache directory" that we had previously.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 thanks for reviewing! if there isn't a separate cache directory, the compiled code for the draft model (EAGLE) will not be saved at all. for models without EAGLE, my understanding is that the backend is invoked only once so this should not impact other models.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luyuzhe111 thanks for the response and clarifying that. Woosuk also filled me in on some more details offline. I understand why we need a separate cache directory.

Which of the "original model" and the "eagle head" get compiled first? (I'm trying to figure out if the first cache dir is for the original model or for the eagle head)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 original model should be compiled first! also if you wanna double check, the transformed code of EAGLE in the cache directory has a slightly different signature with hidden_states as an additional arg. if there is a more elegant solution, that would be great! I think my approach is a bit hacky indeed : )))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the discussion! I added some comments and an assertion into #17662 , please take a look.

I think in the future we'll want a better way to handle multiple compiled regions in a vLLM model, but that will take some re-designing

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luyuzhe111 the asserts in #17662 triggered, which means that this PR does affect non-eagle models

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zou3519 Thanks for the catch! I guess a simple fix would be just to create a separate cache directory only for EAGLE, via looking at the vllm speculative config, for example?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that would work

radeksm pushed a commit to radeksm/vllm that referenced this pull request May 2, 2025
zou3519 added a commit to zou3519/vllm that referenced this pull request May 9, 2025
I'm recording down my understanding of how eagle and the compilation
cache works after discussing
vllm-project#17211 with @luyuzhe111 and
@WoosukKwon.

In the future we likely will have a situation where we want to
torch.compile multiple pieces of code (e.g. decoder and encoder
separately) and then we'll need to refactor the system to support it
(each compiled region needs its own cache directory with its own hash)
But until then the current design seems fine.

Signed-off-by: rzou <zou3519@gmail.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…ect#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…ect#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
cache_dir = self.compilation_config.cache_dir
if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be able to get the component's prefix to use as the cache directory, it could be more meaningful.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make that change, @luyuzhe111 and I were discussing something similar above

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…ect#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants