1 change: 1 addition & 0 deletions benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -259,6 +259,7 @@ def write_results_to_csv(results, filename=None):
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
1 change: 1 addition & 0 deletions benchmarks/kernels/benchmark_trtllm_prefill_attention.py
@@ -274,6 +274,7 @@ def write_results_to_csv(results, filename=None):
quant_dtypes = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
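
Both benchmark scripts now include the FP8-query / FP8-KV / unquantized-output combination. As a rough illustration (not the benchmarks' actual helpers), a `(q_quant_dtype, kv_quant_dtype, o_quant_dtype)` tuple can be consumed like this to build per-tensor FP8 inputs while leaving the output in the model dtype:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn

def to_float8(x: torch.Tensor, dtype: torch.dtype = FP8_DTYPE):
    """Per-tensor symmetric FP8 quantization; returns (quantized, dequant_scale)."""
    finfo = torch.finfo(dtype)
    amax = x.abs().max().clamp(min=1e-12)
    scale = finfo.max / amax                      # quant scale
    x_q = (x * scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_q, scale.reciprocal().float()        # dequant scale

def prepare_inputs(query, key, value, quant_dtypes):
    # Illustrative helper, not part of the benchmark scripts.
    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_scale = k_scale = v_scale = 1.0
    if q_quant_dtype is FP8_DTYPE:
        query, q_scale = to_float8(query)
    if kv_quant_dtype is FP8_DTYPE:
        key, k_scale = to_float8(key)
        value, v_scale = to_float8(value)
    # o_quant_dtype is None: the kernel writes its output in the model dtype.
    return query, key, value, (q_scale, k_scale, v_scale), o_quant_dtype
```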
12 changes: 12 additions & 0 deletions tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -35,6 +35,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
@@ -44,6 +45,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]

NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
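
The new WINDOW_LEFT values exercise sliding-window attention: -1 disables the window, while 127 restricts each query to itself plus the 127 preceding tokens. A minimal sketch of the mask this implies, assuming the usual FlashInfer/FlashAttention left-window convention:

```python
import torch

def sliding_window_mask(seq_len: int, window_left: int) -> torch.Tensor:
    # A key at position j is visible to a query at position i only if it is
    # causal (j <= i) and, when window_left >= 0, within the left window
    # (i - window_left <= j).
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    causal = j <= i
    if window_left < 0:
        return causal
    return causal & (j >= i - window_left)

# window_left=2 lets each token attend to itself plus the 2 previous tokens.
print(sliding_window_mask(8, 2))
```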
@@ -57,6 +59,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
head_size: int,
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)

output = torch.empty(ref_query.shape, dtype=dtype)
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
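
The decode path forwards window_left together with the fused scales. As a reminder of why bmm1_scale and bmm2_scale take these forms, here is a toy (non-kernel) check that folding the dequant scales into the two GEMMs matches attention computed on dequantized tensors:

```python
import torch

# bmm1_scale = q_scale * k_scale * sm_scale scales the QK^T GEMM;
# bmm2_scale = v_scale / o_scale scales the PV GEMM and output re-quantization.
torch.manual_seed(0)
q8, k8, v8 = (torch.randn(4, 8) for _ in range(3))  # stand-ins for dequantized FP8 values
q_scale, k_scale, v_scale, o_scale, sm_scale = 0.1, 0.2, 0.3, 0.5, 8 ** -0.5

ref = torch.softmax((q8 * q_scale) @ (k8 * k_scale).T * sm_scale, dim=-1) @ (v8 * v_scale) / o_scale
fused = torch.softmax(q8 @ k8.T * (q_scale * k_scale * sm_scale), dim=-1) @ v8 * (v_scale / o_scale)
assert torch.allclose(ref, fused, atol=1e-6)
```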
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
head_size: int,
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)

output = torch.empty(ref_query.shape, dtype=dtype)
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 4e-2, 6e-2
else:
rtol, atol = 1e-2, 1e-2
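
The added branch covers an FP8 query/KV cache with an unquantized output and gets its own tolerance pair. The pairs are applied with the standard allclose-style criterion |actual - expected| <= atol + rtol * |expected|, for example:

```python
import torch

rtol, atol = 4e-2, 6e-2  # new FP8-query / unquantized-output branch
# |1.03 - 1.00| = 0.03 <= 0.06 + 0.04 * 1.00 = 0.10, so this passes.
assert torch.allclose(torch.tensor(1.03), torch.tensor(1.00), rtol=rtol, atol=atol)
```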

6 changes: 4 additions & 2 deletions vllm/compilation/fusion_attn.py
@@ -258,8 +258,10 @@ def __init__(self, config: VllmConfig):
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8.register_if_supported(self.patterns)

pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C,
"scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)

if len(attn_layers) == 0:
logger.warning(
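The NVFP4 attention-output quant pattern is now registered only when running on CUDA and the compiled extension actually exposes the scaled_fp4_quant op. A minimal sketch of that guard, with torch.cuda.is_available() standing in for vLLM's current_platform.is_cuda():

```python
import torch

def nvfp4_attn_quant_pattern_available() -> bool:
    # Register the NVFP4 fusion pattern only if the compiled extension exposes
    # the scaled_fp4_quant custom op; hasattr on torch.ops._C returns False
    # when the op was not built in.
    return torch.cuda.is_available() and hasattr(torch.ops._C, "scaled_fp4_quant")
```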
13 changes: 4 additions & 9 deletions vllm/v1/attention/backends/flashinfer.py
@@ -194,19 +194,15 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
FlashInferBackend.validate_head_size(self.head_dim)
self.page_size = self.kv_cache_spec.block_size

self.enable_fusion = (
self.compilation_config.pass_config.enable_attn_fusion)
self.q_data_type = self.model_config.dtype
self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"):
self.kv_cache_dtype = (
FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype))
# Insert FP8 quant for query if FP8 kv cache and attn fusion enabled
if self.enable_fusion:
self.q_data_type = self.kv_cache_dtype
else:
assert self.kv_cache_spec.dtype == self.model_config.dtype
self.kv_cache_dtype = self.kv_cache_spec.dtype
self.q_data_type = self.kv_cache_dtype

self._cascade_wrapper = None # Wrapper for cascade attention

@@ -668,8 +664,6 @@ def forward(

# The attn+quant fusion happens when output_scale is provided.
if output_scale is None:
assert attn_metadata.q_data_type != FP8_DTYPE, \
"Query can only be FP8 if output fusion happened."
assert output_block_scale is None, "output_block_scale "\
"is not supported when fusion has not happened"
else:
@@ -697,7 +691,8 @@ def forward(
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float

# Insert FP8 quant for query
# Insert FP8 quant for query
if attn_metadata.q_data_type == FP8_DTYPE:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape(
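With the metadata builder now requesting an FP8 query whenever the KV cache is FP8 (no longer gated on attention fusion), the forward path quantizes the query before the TRT-LLM call. A rough equivalent of that step, assuming a static per-tensor query scale and plain torch ops in place of vLLM's ops.scaled_fp8_quant:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn

def quantize_query_fp8(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    # Flatten heads, quantize per-tensor with the layer's static scale,
    # then restore the (num_tokens, num_heads, head_size) layout.
    num_tokens, num_heads, head_size = query.shape
    q2d = query.reshape(num_tokens, num_heads * head_size)
    finfo = torch.finfo(FP8_DTYPE)
    q_fp8 = (q2d / q_scale).clamp(finfo.min, finfo.max).to(FP8_DTYPE)
    return q_fp8.reshape(num_tokens, num_heads, head_size)
```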