
Commit 1918711

fix comment
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
1 parent 6248535 commit 1918711

File tree

5 files changed: +11 −12 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 1 deletion
@@ -664,7 +664,7 @@ steps:
   # Attention
   # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
   - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
-  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py
+  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
   - pytest -v -s tests/kernels/test_cutlass_mla_decode.py
   # Quantization
   - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'

benchmarks/kernels/benchmark_trtllm_decode_attention.py

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ def benchmark_decode(
     device = "cuda"
     torch.manual_seed(0)
 
-    # Currently only HEAD_GRP_SIZE == 8 is supported
     HEAD_GRP_SIZE = 8
     MAX_SEQ_LEN = max_seq_len

benchmarks/kernels/benchmark_trtllm_prefill_attention.py

Lines changed: 0 additions & 1 deletion
@@ -40,7 +40,6 @@ def benchmark_prefill(
     torch.set_default_device("cuda")
     torch.manual_seed(0)
 
-    # Currently only HEAD_GRP_SIZE == 8 is supported
     HEAD_GRP_SIZE = 8
     MAX_SEQ_LEN = max_seq_len
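
In both benchmarks the removed comment mirrored the group-size cap that the next file drops from use_trtllm_attention (the num_qo_heads // num_kv_heads > 8 guard). A minimal sketch of that grouped-query ratio, using hypothetical head counts rather than the benchmarks' actual configurations:

# Hypothetical head counts for illustration; not taken from the benchmark defaults.
num_qo_heads = 64  # query/output heads
num_kv_heads = 8   # key/value heads shared by each group

# HEAD_GRP_SIZE corresponds to this ratio; the removed guard rejected ratios > 8.
head_grp_size = num_qo_heads // num_kv_heads
assert num_qo_heads % num_kv_heads == 0 and head_grp_size == 8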

vllm/utils/flashinfer.py

Lines changed: 0 additions & 1 deletion
@@ -159,7 +159,6 @@ def use_trtllm_attention(
 
     # Check if the dimensions are supported by TRTLLM decode attention
     if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
-            or num_qo_heads // num_kv_heads > 8
             or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
         return False
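
For reference, the dimension gate as it reads after this deletion, lifted into a standalone helper; _dims_supported is a hypothetical name and this is only the check itself, not the rest of use_trtllm_attention:

from typing import Optional


def _dims_supported(attn_head_size: Optional[int], num_qo_heads: Optional[int],
                    num_kv_heads: Optional[int]) -> bool:
    # Check if the dimensions are supported by TRTLLM decode attention.
    # The group-size cap (num_qo_heads // num_kv_heads > 8) is gone; only the
    # divisibility and head_size == 128 requirements remain.
    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
        return False
    return True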

vllm/v1/attention/backends/flashinfer.py

Lines changed: 10 additions & 8 deletions
@@ -523,14 +523,16 @@ def build(self,
         head_dim = self.kv_cache_spec.head_size
 
         # currently prefill trtllm attention does not support fp8 kv cache
-        prefill_use_trtllm = not cache_dtype.startswith(
-            "fp8") and use_trtllm_attention(num_prefill_tokens, max_seq_len,
-                                            cache_dtype, num_qo_heads,
-                                            num_kv_heads, head_dim)
-        decode_use_trtllm = use_trtllm_attention(num_decode_tokens,
-                                                 max_seq_len, cache_dtype,
-                                                 num_qo_heads, num_kv_heads,
-                                                 head_dim)
+        # trtllm may not support sliding window
+        prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
+                              and not cache_dtype.startswith("fp8")
+                              and use_trtllm_attention(
+                                  num_prefill_tokens, max_seq_len, cache_dtype,
+                                  num_qo_heads, num_kv_heads, head_dim))
+        decode_use_trtllm = (self.global_hyperparameters.window_left == -1
+                             and use_trtllm_attention(
+                                 num_decode_tokens, max_seq_len, cache_dtype,
+                                 num_qo_heads, num_kv_heads, head_dim))
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
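
Taken together, the builder now enables the TRT-LLM paths only when sliding window is disabled, and additionally skips the prefill path for fp8 KV caches. A condensed sketch of that decision; _select_trtllm_paths is a hypothetical helper and dims_ok stands in for the use_trtllm_attention(...) result:

def _select_trtllm_paths(window_left: int, cache_dtype: str,
                         dims_ok: bool) -> tuple[bool, bool]:
    # trtllm may not support sliding window; window_left == -1 means "disabled".
    no_sliding_window = window_left == -1
    # currently prefill trtllm attention does not support fp8 kv cache
    prefill_use_trtllm = (no_sliding_window
                          and not cache_dtype.startswith("fp8") and dims_ok)
    decode_use_trtllm = no_sliding_window and dims_ok
    return prefill_use_trtllm, decode_use_trtllm


# Example: an fp8 KV cache disables only the prefill path.
print(_select_trtllm_paths(window_left=-1, cache_dtype="fp8_e4m3", dims_ok=True))
# -> (False, True)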
