
feat: kv fp8 quant calibration for fa3 and flashinfer #935


Open

blueswhen wants to merge 3 commits into main from kvcache_fp8

Conversation

blueswhen (Collaborator)

No description provided.


@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @blueswhen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a comprehensive FP8 KV cache quantization feature, enabling significant memory savings during inference. It provides a robust framework for offline calibration to determine optimal quantization scales and seamlessly integrates FP8 support into the high-performance FlashAttention v3 and FlashInfer attention mechanisms. This enhancement aims to improve the efficiency and scalability of large language model serving.

Highlights

  • FP8 KV Cache Quantization: This PR introduces support for FP8 (8-bit floating point) quantization of Key (K) and Value (V) caches. This can significantly reduce memory consumption for KV caches, potentially enabling longer context windows or larger batch sizes.
  • Offline Calibration Mechanism: A new offline calibration process lets users collect activation statistics (absolute maxima) of the KV caches during inference. These statistics are then used to compute FP8 quantization scales, which can be exported to a JSON file and loaded for subsequent quantized inference runs (see the sketch after this list).
  • Integration with Attention Backends: The FP8 KV cache functionality is deeply integrated into the FlashAttention v3 (FA3) and FlashInfer attention kernels. This ensures that the quantized KV data is correctly handled during both context (prefill) and token decoding phases, including the quantization of query (Q) tensors for FA3.
  • Flexible Scaling Options: The implementation supports both per-head and per-tensor quantization scaling for the KV caches, providing flexibility to choose the appropriate granularity for quantization based on model characteristics and performance requirements.
  • New Triton Kernels: Dedicated Triton kernels have been developed for efficient FP8 quantization of KV data during the copy operation to the memory buffer, and for per-head FP8 quantization of query tensors.
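
To make the scale computation concrete, below is a minimal standalone sketch (not lightllm's actual code; the shapes, clamp floor, and variable names are assumptions) of deriving per-head FP8 e4m3 scales from calibrated absolute maxima and applying them to a KV tensor:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max          # 448.0 for e4m3

abs_max = torch.rand(8) * 4 + 1e-3                       # pretend per-head |max| for one layer's K cache
scale = (abs_max / FP8_MAX).clamp(min=1e-12)             # dequantization scale, one per head

x = torch.randn(16, 8, 128)                              # (tokens, kv_heads, head_dim)
x_fp8 = (x / scale.view(1, -1, 1)).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
x_deq = x_fp8.to(torch.float32) * scale.view(1, -1, 1)   # approximate reconstruction

Per-tensor scaling follows the same pattern with a single scalar abs_max instead of one value per head.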


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces FP8 quantization for KV cache, targeting improved memory efficiency. It includes mechanisms for both loading pre-calibrated FP8 scales and performing online calibration. The changes are well-integrated into the existing Llama model inference pipeline, with specific adaptations for FA3 and FlashInfer attention backends. New command-line arguments and environment variables provide control over the calibration process. Unit tests for the new FP8 attention kernels have been added.

Several areas for potential improvement or closer inspection include the strictness of assertions in the new FP8 tests, error message clarity in configuration loading, and ensuring robust handling of scale types (list vs. tensor) when interacting with different kernel backends. The increase in LIGHTLLM_TOKEN_MAX_BYTES should also be noted for its resource implications.

Comment on lines 358 to 367
offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
k_descale = (
    offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
    if offline_scales is not None
    else ones_scales
)
v_descale = (
    offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
    if offline_scales is not None
    else ones_scales
)

high

In _context_attention_flashattention_fp8 (and similarly in _token_decode_attention_flashattention_fp8), the offline_scales are indexed like offline_scales[self.layer_num_][: self.tp_k_head_num_]. This assumes offline_scales[self.layer_num_] is a 1D tensor suitable for slicing.

If get_env_start_args().enable_fa3 is false (meaning per-tensor scaling was configured during OfflineFP8QuantManager initialization), offline_scales[self.layer_num_] would be a 0D (scalar) tensor. Slicing a scalar tensor like [: self.tp_k_head_num_] will result in a runtime error.

This suggests that the FA3 FP8 attention path implicitly requires per-head scales (i.e., enable_fa3 must have been true when scales were loaded/calibrated). If calibration_fp8kv mode can be active with FA3 while enable_fa3 (for scale configuration) is false, this will break.

Consider adding a check or ensuring that the scale format matches the expectation of this FA3-specific kernel. For example, assert that offline_scales[self.layer_num_].dim() == 1 or handle the scalar case appropriately (e.g., by broadcasting if FA3 can work with a single scale for all heads, though that defeats per-head scaling intent).
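
One possible shape of that guard, as a sketch that reuses the names from the snippet above (the assertion message is illustrative):

offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
if offline_scales is not None:
    layer_scales = offline_scales[self.layer_num_]
    # FA3 FP8 path needs one scale per K head plus one per V head (a 1D tensor per layer).
    assert layer_scales.dim() == 1, "FA3 FP8 attention requires per-head KV scales, got a per-tensor scalar"
    k_descale = layer_scales[: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
    v_descale = layer_scales[self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
else:
    k_descale = v_descale = ones_scales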

Comment on lines 227 to 229
offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales_list
k_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
v_descale = offline_scales[self.layer_num_] if offline_scales is not None else None


medium

In _context_attention_flashinfer_kernel_fp8 (and _token_decode_attention_flashinfer_fp8), offline_scales is infer_state.mem_manager.offline_fp8_quant_manager.scales_list. Then k_descale = offline_scales[self.layer_num_].

If per-head scaling was used (e.g., enable_fa3=True during scale loading/calibration), offline_scales[self.layer_num_] will be a Python list of numbers. The FlashInfer API for k_scale/v_scale expects None, a float, or a torch.Tensor. Passing a list directly might not work or lead to unexpected behavior.

If FlashInfer might be used with per-head scales from scales_list, ensure k_descale and v_descale are converted to torch.Tensor before being passed to infer_state.prefill_wrapper.run or infer_state.decode_wrapper.run.
Example conversion: k_descale = torch.tensor(offline_scales[self.layer_num_], device=q.device, dtype=torch.float32) if isinstance(offline_scales[self.layer_num_], list) else offline_scales[self.layer_num_] (adjust dtype/device as needed).
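
Expressed as code, that normalization could look like the following sketch (names taken from the snippet under review; dtype and device may need adjusting):

layer_scales = offline_scales[self.layer_num_] if offline_scales is not None else None
if isinstance(layer_scales, list):
    # Per-head scales are stored as a Python list; FlashInfer's k_scale/v_scale accept
    # None, a float, or a torch.Tensor, so convert before calling the wrapper.
    layer_scales = torch.tensor(layer_scales, dtype=torch.float32, device=q.device)
k_descale = layer_scales
v_descale = layer_scales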

logger.warning("scales is None, no kv_quant_calibration_config_path be set")

def enable_calibration(self):
    assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode"


medium

The assertion assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode" is a good runtime check. It might be beneficial to also log this information as a warning or info when calibration mode is entered, so users are aware of this limitation proactively if they happen to have cudagraph enabled elsewhere.
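
A minimal sketch of that suggestion, using the standard logging module (lightllm's own logger helper would work just as well; the message text is illustrative):

import logging

logger = logging.getLogger(__name__)

def enable_calibration(self):
    # Keep the hard failure, but also surface the limitation in the logs.
    assert get_env_start_args().disable_cudagraph, "Calibration is not supported in cudagraph mode"
    logger.warning("Entering KV FP8 calibration mode; cuda graphs must remain disabled for this run.")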

Comment on lines 103 to 106
if get_env_start_args().enable_fa3:
    kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
else:
    kv_max = kv_buffer.abs().amax(dim=()).to(torch.float32)


medium

When get_env_start_args().enable_fa3 is true, kv_max is calculated as kv_buffer.abs().amax(dim=(0, 2)). If kv_buffer has a shape like (current_batch_seq_len, num_kv_heads, head_dim), then amax(dim=(0, 2)) results in a tensor of shape (num_kv_heads,). This seems correct for per-head quantization.

When enable_fa3 is false, kv_max is kv_buffer.abs().amax(dim=()), which is a scalar (per-tensor). This also seems correct.

This logic relies on kv_buffer passed to update_calibration_data having a consistent shape and representing the K and V values for which scales are being calibrated. Ensure that the buffer indeed corresponds to the structure expected by these amax operations (e.g., if it's per-head, it should contain data for all heads that need individual scaling factors).
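
For reference, a tiny shape check of the two reduction modes quoted above (dummy sizes, assuming a (seq_len, num_kv_heads, head_dim) layout):

import torch

kv_buffer = torch.randn(512, 8, 128)                        # (seq_len, num_kv_heads, head_dim)

per_head_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
print(per_head_max.shape)                                   # torch.Size([8]): one max per KV head

per_tensor_max = kv_buffer.abs().amax(dim=()).to(torch.float32)
print(per_tensor_max.shape)                                 # torch.Size([]): a single scalar max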

Comment on lines +110 to +130
if dist.is_initialized() and dist.get_world_size() > 1:
    if get_env_start_args().enable_fa3:
        k_max, v_max = torch.chunk(self.abs_max, 2, dim=-1)
        k_max = k_max.contiguous()
        v_max = v_max.contiguous()
        gathered_k_max = [torch.zeros_like(k_max) for _ in range(dist.get_world_size())]
        gathered_v_max = [torch.zeros_like(v_max) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered_k_max, k_max, group=None, async_op=False)
        dist.all_gather(gathered_v_max, v_max, group=None, async_op=False)
        k_max = torch.cat(gathered_k_max, dim=-1)
        v_max = torch.cat(gathered_v_max, dim=-1)
        final_abs_max = torch.cat((k_max, v_max), dim=-1)
    else:
        dist.all_reduce(self.abs_max, op=dist.ReduceOp.MAX, group=None, async_op=False)


medium

The distributed logic for aggregating abs_max values seems correct:

  • For FA3 (enable_fa3=True), it gathers K and V maxes separately across ranks and then concatenates them. This assumes self.abs_max on each rank stores [layer_num, 2 * head_num_per_rank] and head_num (passed to __init__) is the per-rank head count.
  • Otherwise (per-tensor scaling), it uses all_reduce with MAX.

This logic appears sound under the assumption that head_num provided to OfflineFP8QuantManager is the local head number for the current TP rank.

@@ -2,7 +2,7 @@
import ctypes
from typing import Tuple

-LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 128))
+LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1024))


medium

The default value for LIGHTLLM_TOKEN_MAX_BYTES has been increased from 128 to 1024. This is a significant increase and will impact the shared memory allocation for the token queue. While this might be necessary for handling larger tokens (e.g., from multimodal inputs or complex special tokens), it's worth noting the resource implication. Ensure this change is intentional and tested for its impact on overall memory usage, especially in environments with many concurrent requests or limited shared memory.

Suggested change:
-LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1024))
+LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 1024))  # Increased from 128

# assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1)
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1.item() == 1


medium

The assertion assert cos_sim1.item() == 1 is extremely strict for comparing FP8 quantized operations against higher precision ones. FP8 quantization inherently introduces some loss of precision. Achieving perfect cosine similarity (1.0) usually implies either the test data is such that quantization errors are negligible/cancel out, or the reference path is also effectively operating at a similar precision for the comparison.

Consider using torch.allclose with appropriate atol and rtol values, or asserting that the cosine similarity is very close to 1 (e.g., assert cos_sim1.item() > 0.99). This would make the test more robust to minor, expected numerical differences. The commented-out torch.allclose line (147) suggests this was previously considered.
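
For example, a less brittle version of the check (thresholds are illustrative, and the stand-in tensors below replace the kernel outputs o and o1 from the test):

import torch
import torch.nn.functional as F

o = torch.randn(4, 64)                    # stand-in for the FP8 kernel output
o1 = o + 0.01 * torch.randn_like(o)       # stand-in for the reference output

cos_sim1 = F.cosine_similarity(o, o1).mean()
assert cos_sim1.item() > 0.99, f"cosine similarity too low: {cos_sim1.item():.4f}"
# or, with explicit elementwise tolerances:
# assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1)

The same relaxation applies to the other FP8 kernel tests flagged below.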

# assert torch.allclose(o, o1, atol=1e-2, rtol=2e-1)
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1 == 1


medium

Similar to other FP8 tests, assert cos_sim1 == 1 is a very strict assertion. FP8 quantization usually introduces some precision loss. It's recommended to use torch.allclose with suitable tolerances or check if cos_sim1 is very close to 1 (e.g., > 0.99) to account for expected numerical differences.

# assert torch.allclose(o, o1, atol=1e-1, rtol=1e-1)
cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1 == 1


medium

The assertion assert cos_sim1 == 1 is very strict for FP8 comparisons. Consider using torch.allclose with tolerances or checking if cos_sim1 > 0.99 to allow for minor, expected numerical deviations due to quantization.


cos_sim1 = F.cosine_similarity(o, o1).mean()
print(cos_sim1)
assert cos_sim1 == 1.0


medium

The assertion assert cos_sim1 == 1.0 is highly stringent for comparing FP8 quantized results with a reference. It's advisable to use torch.allclose with appropriate absolute and relative tolerances, or to assert that cos_sim1 is very close to 1.0 (e.g., cos_sim1.item() > 0.99), to account for the inherent precision differences with FP8 arithmetic.

@blueswhen force-pushed the kvcache_fp8 branch 4 times, most recently from 92c274a to 3b46426 on June 23, 2025 at 07:08.