feat: kv fp8 quant calibration for fa3 and flashinfer #935
base: main
Conversation
Summary of Changes
Hello @blueswhen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a comprehensive FP8 KV cache quantization feature, enabling significant memory savings during inference. It provides a robust framework for offline calibration to determine optimal quantization scales and seamlessly integrates FP8 support into the high-performance FlashAttention v3 and FlashInfer attention mechanisms. This enhancement aims to improve the efficiency and scalability of large language model serving.
Highlights
- FP8 KV Cache Quantization: This PR introduces support for FP8 (8-bit floating point) quantization of Key (K) and Value (V) caches. This can significantly reduce memory consumption for KV caches, potentially enabling longer context windows or larger batch sizes.
- Offline Calibration Mechanism: A new offline calibration process is added, allowing users to collect activation statistics (absolute maximums) of KV caches during inference. These statistics are then used to compute optimal FP8 quantization scales, which can be exported to a JSON file and loaded for subsequent quantized inference runs (see the sketch after this list).
- Integration with Attention Backends: The FP8 KV cache functionality is deeply integrated into the FlashAttention v3 (FA3) and FlashInfer attention kernels. This ensures that the quantized KV data is correctly handled during both context (prefill) and token decoding phases, including the quantization of query (Q) tensors for FA3.
- Flexible Scaling Options: The implementation supports both per-head and per-tensor quantization scaling for the KV caches, providing flexibility to choose the appropriate granularity for quantization based on model characteristics and performance requirements.
- New Triton Kernels: Dedicated Triton kernels have been developed for efficient FP8 quantization of KV data during the copy operation to the memory buffer, and for per-head FP8 quantization of query tensors.
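As a rough companion to the calibration highlight above, here is a minimal sketch of how collected abs-max statistics could be turned into FP8 quantization scales. It is not the PR's code: the function name, the tensor shapes, and the use of the e4m3 format (maximum ≈448) are assumptions.

# Minimal sketch (not the PR's implementation) of mapping abs-max statistics
# to FP8 e4m3 quantization scales. Shapes and names are illustrative assumptions.
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448.0 for e4m3

def compute_fp8_scales(abs_max: torch.Tensor, per_head: bool) -> torch.Tensor:
    """abs_max: (num_layers, 2 * num_kv_heads) for per-head scaling,
    or (num_layers,) for per-tensor scaling."""
    if not per_head:
        abs_max = abs_max.reshape(-1)  # one value per layer
    # Map the observed dynamic range onto the FP8 representable range;
    # clamp avoids zero scales for heads/layers that never saw activations.
    return (abs_max / FP8_E4M3_MAX).clamp(min=1e-12).to(torch.float32)

# At attention time, multiplying the FP8 values back by the scale ("descale")
# approximately restores the original magnitude: kv ≈ kv_fp8.float() * scale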
Code Review
This pull request introduces FP8 quantization for KV cache, targeting improved memory efficiency. It includes mechanisms for both loading pre-calibrated FP8 scales and performing online calibration. The changes are well-integrated into the existing Llama model inference pipeline, with specific adaptations for FA3 and FlashInfer attention backends. New command-line arguments and environment variables provide control over the calibration process. Unit tests for the new FP8 attention kernels have been added.
Several areas for potential improvement or closer inspection include the strictness of assertions in the new FP8 tests, error message clarity in configuration loading, and ensuring robust handling of scale types (list vs. tensor) when interacting with different kernel backends. The increase in LIGHTLLM_TOKEN_MAX_BYTES should also be noted for its resource implications.
offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
k_descale = (
    offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
    if offline_scales is not None
    else ones_scales
)
v_descale = (
    offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
    if offline_scales is not None
    else ones_scales
In _context_attention_flashattention_fp8 (and similarly in _token_decode_attention_flashattention_fp8), the offline_scales are indexed as offline_scales[self.layer_num_][: self.tp_k_head_num_]. This assumes offline_scales[self.layer_num_] is a 1D tensor suitable for slicing.

If get_env_start_args().enable_fa3 is false (meaning per-tensor scaling was configured during OfflineFP8QuantManager initialization), offline_scales[self.layer_num_] would be a 0D (scalar) tensor, and slicing a scalar tensor with [: self.tp_k_head_num_] will raise a runtime error.

This suggests that the FA3 FP8 attention path implicitly requires per-head scales (i.e., enable_fa3 must have been true when the scales were loaded/calibrated). If calibration_fp8kv mode can be active with FA3 while enable_fa3 (for scale configuration) is false, this path will break.

Consider adding a check or otherwise ensuring that the scale format matches the expectation of this FA3-specific kernel: for example, assert that offline_scales[self.layer_num_].dim() == 1, or handle the scalar case appropriately (e.g., by broadcasting, if FA3 can work with a single scale for all heads, though that defeats the per-head scaling intent).
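As one possible shape for such a guard, the following hedged sketch reuses the names quoted from the PR (offline_scales, self.layer_num_, self.tp_k_head_num_, infer_state); the branching itself is an illustrative assumption, not the PR's code.

layer_scales = offline_scales[self.layer_num_]
if layer_scales.dim() == 0:
    # Per-tensor scale reached the FA3 per-head path: either broadcast explicitly,
    # or raise if per-head calibration is a hard requirement for FA3.
    layer_scales = layer_scales.repeat(2 * self.tp_k_head_num_)
assert layer_scales.dim() == 1, "FA3 FP8 path expects per-head K/V scales"
k_descale = layer_scales[: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
v_descale = layer_scales[self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)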
offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales_list
k_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
v_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
In _context_attention_flashinfer_kernel_fp8 (and _token_decode_attention_flashinfer_fp8), offline_scales is infer_state.mem_manager.offline_fp8_quant_manager.scales_list, and then k_descale = offline_scales[self.layer_num_].

If per-head scaling was used (e.g., enable_fa3=True during scale loading/calibration), offline_scales[self.layer_num_] will be a Python list of numbers. The FlashInfer API for k_scale/v_scale expects None, a float, or a torch.Tensor; passing a list directly might not work or may lead to unexpected behavior.

If FlashInfer might be used with per-head scales from scales_list, ensure k_descale and v_descale are converted to torch.Tensor before being passed to infer_state.prefill_wrapper.run or infer_state.decode_wrapper.run.

Example conversion (adjust dtype/device as needed): k_descale = torch.tensor(offline_scales[self.layer_num_], device=q.device, dtype=torch.float32) if isinstance(offline_scales[self.layer_num_], list) else offline_scales[self.layer_num_]
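Written out over multiple lines, that conversion might look like the sketch below (names follow the quoted PR code; the float32 dtype and q.device choices remain assumptions to adjust to FlashInfer's expectations).

if offline_scales is not None:
    layer_scales = offline_scales[self.layer_num_]
    if isinstance(layer_scales, list):
        # FlashInfer's k_scale/v_scale accept None, a float, or a torch.Tensor,
        # so lift the per-head Python list into a tensor first.
        layer_scales = torch.tensor(layer_scales, device=q.device, dtype=torch.float32)
    k_descale = v_descale = layer_scales
else:
    k_descale = v_descale = None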
logger.warning("scales is None, no kv_quant_calibration_config_path be set")

def enable_calibration(self):
    assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode"
The assertion assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode" is a good runtime check. It might be beneficial to also log a warning or info message when calibration mode is entered, so users are proactively aware of this limitation if they happen to have cudagraph enabled elsewhere.
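A minimal sketch of that proactive logging, assuming the logger and get_env_start_args names already used in the PR (the exact message wording is an assumption):

def enable_calibration(self):
    if not get_env_start_args().disable_cudagraph:
        # Surface the limitation before the hard assert below trips.
        logger.warning("KV FP8 calibration requires cudagraph to be disabled")
    assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode"
    logger.info("KV cache FP8 calibration mode enabled")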
lightllm/common/mem_manager.py (outdated)
if get_env_start_args().enable_fa3:
    kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
else:
    kv_max = kv_buffer.abs().amax(dim=()).to(torch.float32)
When get_env_start_args().enable_fa3 is true, kv_max is calculated as kv_buffer.abs().amax(dim=(0, 2)). If kv_buffer has a shape like (current_batch_seq_len, num_kv_heads, head_dim), then amax(dim=(0, 2)) results in a tensor of shape (num_kv_heads,), which seems correct for per-head quantization.

When enable_fa3 is false, kv_max is kv_buffer.abs().amax(dim=()), which is a scalar (per-tensor). This also seems correct.

This logic relies on the kv_buffer passed to update_calibration_data having a consistent shape and representing the K and V values for which scales are being calibrated. Ensure that the buffer indeed corresponds to the structure expected by these amax operations (e.g., if scaling is per-head, it should contain data for all heads that need individual scaling factors).
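To make those shapes concrete, here is a small self-contained sketch; the buffer shape is an assumption taken from the comment above, not from the PR's call site.

import torch

kv_buffer = torch.randn(16, 8, 128)  # (current_batch_seq_len=16, num_kv_heads=8, head_dim=128)

per_head_max = kv_buffer.abs().amax(dim=(0, 2))  # reduce seq and head_dim -> shape (8,), one value per KV head
per_tensor_max = kv_buffer.abs().amax(dim=())    # empty dim tuple reduces every dimension -> 0-dim scalar

print(per_head_max.shape)    # torch.Size([8])
print(per_tensor_max.shape)  # torch.Size([])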
if dist.is_initialized() and dist.get_world_size() > 1:
    if get_env_start_args().enable_fa3:
        k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
        k_max = k_max.contiguous()
        v_max = v_max.contiguous()
        gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
        gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
        dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
        k_max = torch.cat(gathered_k_max, dim=-1)
        v_max = torch.cat(gathered_v_max, dim=-1)
        final_abs_max = torch.cat((k_max, v_max), dim=-1)
    else:
        dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False)
The distributed logic for aggregating abs_max values seems correct:

- For FA3 (enable_fa3=True), it gathers the K and V maxes separately across ranks and then concatenates them. This assumes self.abs_max on each rank stores [layer_num, 2 * head_num_per_rank] and that head_num (passed to __init__) is the per-rank head count.
- Otherwise (per-tensor scaling), it uses all_reduce with MAX.

This logic appears sound under the assumption that head_num provided to OfflineFP8QuantManager is the local head number for the current TP rank.
@@ -2,7 +2,7 @@
import ctypes
from typing import Tuple

-LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 128))
+LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1024))
The default value for LIGHTLLM_TOKEN_MAX_BYTES has been increased from 128 to 1024. This is a significant increase and will impact the shared memory allocation for the token queue. While this might be necessary for handling larger tokens (e.g., from multimodal inputs or complex special tokens), it's worth noting the resource implication. Ensure this change is intentional and tested for its impact on overall memory usage, especially in environments with many concurrent requests or limited shared memory.
Suggested change:
-LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1024))
+LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1024))  # Increased from 128
# assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1)
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1.item() == 1
The assertion assert cos_sim1.item() == 1 is extremely strict for comparing FP8 quantized operations against higher-precision ones. FP8 quantization inherently introduces some loss of precision; achieving perfect cosine similarity (1.0) usually implies either that the test data makes quantization errors negligible or cancel out, or that the reference path is effectively operating at a similar precision for the comparison.

Consider using torch.allclose with appropriate atol and rtol values, or asserting that the cosine similarity is very close to 1 (e.g., assert cos_sim1.item() > 0.99). This would make the test more robust to minor, expected numerical differences. The commented-out torch.allclose on line 147 suggests this was previously considered.
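For instance, a self-contained sketch of the relaxed check; the tolerances and the 0.99 threshold are illustrative assumptions, and o/o1 here are random stand-ins for the test's reference and FP8 outputs.

import torch
import torch.nn.functional as F

o = torch.randn(4, 128)              # stand-in for the reference attention output
o1 = o + 0.01 * torch.randn_like(o)  # stand-in for the FP8 output, with small quantization-like error

cos_sim1 = F.cosine_similarity(o, o1).mean()
# Tolerate the expected FP8 error instead of demanding exact equality ...
assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1)
# ... and/or require near-parallel outputs rather than a bit-exact cosine similarity of 1.
assert cos_sim1.item() > 0.99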
# assert torch.allclose(o, o1, atol=1e-2, rtol=2e-1)
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1 == 1
# assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1)
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1 == 1
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1 == 1.0
The assertion assert cos_sim1 == 1.0 is highly stringent for comparing FP8 quantized results with a reference. It's advisable to use torch.allclose with appropriate absolute and relative tolerances, or to assert that cos_sim1 is very close to 1.0 (e.g., cos_sim1.item() > 0.99), to account for the inherent precision differences with FP8 arithmetic.
92c274a to 3b46426
No description provided.