Skip to content

[Core/Bugfix] Add FP8 K/V Scale and dtype conversion for prefix/prefill Triton Kernel #7208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

jon-chuang
Copy link
Contributor

@jon-chuang jon-chuang commented Aug 6, 2024

Fix the FP8 Triton kernel issue. Should enable FP8 KV Cache to be used with:

  1. chunked prefill
  2. prefix caching

FIX #4381 #3880 #3156 #3880

TODO:

Notes:

  1. @comaniac mentions upcoming flashinfer support, but I think supporting it in Triton as fallback is good for users who want to avoid installing heavyweight dependency
  2. Regarding correctness, you can see that this PR merely brings up triton kernel to parity with vLLM's custom torch CUDA paged attention kernels:
    k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
    ,
    v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
    , where scaled_convert is defined as:
    return float_to_half(half_to_float(tmp.x) * scale);

Example Output (decode v.s. chunked prefill with FP8 KV Cache):
image

Performance is similar to decode in low load case on max_chunk=16,max_sequence_len=512
image
image

I attribute the slightly lower perf to lack of perf tuning of the triton kernels

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

CC: @comaniac

@jon-chuang jon-chuang marked this pull request as draft August 6, 2024 15:07
Copy link

github-actions bot commented Aug 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 6, 2024
@jon-chuang jon-chuang marked this pull request as ready for review August 6, 2024 15:24
@jon-chuang
Copy link
Contributor Author

@comaniac how can I trigger the CI? I have no dev env for vllm currently

@jon-chuang jon-chuang changed the title [Core/Bugfix] Add FP8 K/V Scale and dtype conversion for prefix/prefill [Core/Bugfix] Add FP8 K/V Scale and dtype conversion for prefix/prefill Triton Kernel Aug 6, 2024
@comaniac comaniac self-assigned this Aug 6, 2024
@comaniac
Copy link
Collaborator

comaniac commented Aug 6, 2024

The CI is already triggered.

@jon-chuang
Copy link
Contributor Author

I guess it is waiting for a runner to schedule buildkite/ci-aws/pr?

@jon-chuang
Copy link
Contributor Author

Turns out I actually needed to push new commit after converting from draft status to trigger CI

@comaniac
Copy link
Collaborator

comaniac commented Aug 6, 2024

@comaniac how can I trigger the CI? I have no dev env for vllm currently

Does that mean you cannot verify this PR locally? We should avoid using CI to verify and debug because it's slow and costly.

@comaniac comaniac removed the ready ONLY add when PR is ready to merge/full CI is needed label Aug 6, 2024
@jon-chuang jon-chuang marked this pull request as draft August 6, 2024 18:36
@jon-chuang
Copy link
Contributor Author

jon-chuang commented Aug 6, 2024

Hmm, actually; shouldn't Triton have support for mixed FP8 matmul (E5M2 only)?

Anw, this is a story for next time; as we first need to ask users if they want to enable lower precision matmul with one operand in FP8

EDIT:
This is almost certainly something we don't want to use; it upcasts to FP16/BF16 without any scale factor prior to MMA.

image

@jon-chuang jon-chuang marked this pull request as ready for review August 6, 2024 19:16
Comment on lines 18 to 22
E5M2_KV_MODELS = [ # type: ignore
# does not work with fp8 kv cache kernel
# - CUDA illegal memory access - undiagnosed
# "facebook/opt-125m",
]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does other models work with E5M2? For example does that work if you just use the FP16 Qwen2-1.5B-Instruct checkpoint?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't seem to find any model that has quantized kv E5M2; if not explicitly quantized with KV scales, the log probs differ due to the inaccuracy during low-precision float conversion

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E5M2 doesn't need scaling factors and the accuracy shouldn't drop significantly. If it doesn't make sense to check the logprobs for E5M2, we could skip testing them.

Copy link
Contributor Author

@jon-chuang jon-chuang Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, the reason the test failed was because Qwen2-1.5B is actually BF16 (exponent - 8 bit).
I tried with llama-2-7b which is FP16 (exponent 5 bits) and the tests pass now.

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last batch of minor comments

Comment on lines 117 to 119
if chunked_prefill_token_size != 1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So when chunked_prefill_token_size==1 you disable chunked prefill. Then isn't the following 2 runners always the same? In this case why we need to test this case (looks like the existing test case also has the same issue)?

Copy link
Contributor Author

@jon-chuang jon-chuang Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I was following the previous test blindly, which is totally wrong as you mentioned. If you want I can get rid of the parameter 1 for both.

Alternately I can also remove the if statement which will allow testing the code path with enable_chunked_prefill=True, prefill_size=1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remove the if statement from both which is useless as you said

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah let's just make it reasonable. Also cc @rkooo567 who may know what it forms like this.

Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jon-chuang
Copy link
Contributor Author

jon-chuang commented Aug 12, 2024

Thank you for help in reviewing @comaniac. Need your help to merge as I have no write access.

@comaniac comaniac enabled auto-merge (squash) August 12, 2024 22:33
@comaniac comaniac merged commit a046f86 into vllm-project:main Aug 12, 2024
55 checks passed
@pavanimajety
Copy link
Contributor

@jon-chuang
Copy link
Contributor Author

Yes, the comment should have been removed

@stenreijers
Copy link

stenreijers commented Aug 19, 2024

can #3234 be closed @jon-chuang ?

@jon-chuang
Copy link
Contributor Author

Yes @chenxu2048

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
…ll Triton Kernel (vllm-project#7208)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
…ll Triton Kernel (vllm-project#7208)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Chunked prefill doesn't seem to work when --kv-cache-dtype fp8
6 participants