Commit 1fa8007

fix fp8 kv cache

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>

1 parent e8cc53a, commit 1fa8007

File tree

1 file changed: 5 additions, 4 deletions

vllm/v1/attention/backends/flashinfer.py

@@ -406,8 +406,9 @@ def _plan(self, attn_metadata: FlashInferMetadata):
             attn_metadata.decode_wrapper = self._get_decode_wrapper()
             if not FlashInferBackend.use_trtllm_decode_attention(
                     self._num_decodes, attn_metadata.max_seq_len,
-                    attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                    self.runner.cache_config.cache_dtype,
+                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                    attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
                     attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
                     attn_metadata.paged_kv_indices,
@@ -594,10 +595,10 @@ def forward(
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
-            kv_cache: shape -
+            kv_cache: shape -
             # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
             # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
-
+

             attn_metadata: Metadata for attention.
         Returns:
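
The substantive change is in the first hunk: FlashInferBackend.use_trtllm_decode_attention now receives self.runner.cache_config.cache_dtype, vLLM's string-valued cache-dtype setting (e.g. "auto" or "fp8"), instead of attn_metadata.kv_data_type, a torch dtype. The second hunk appears to be a whitespace-only docstring cleanup. Below is a minimal sketch of how a gate like this might consume the string; the function name, the supported-dtype set, and the head_dim == 128 constraint are illustrative placeholders, not vLLM's or FlashInfer's actual selection logic.

# Hypothetical sketch, assuming cache_dtype is vLLM's CacheConfig string
# ("auto", "fp8", "fp8_e4m3", "fp8_e5m2"); not the real implementation.
def use_trtllm_decode_attention_sketch(num_decodes: int,
                                       max_seq_len: int,
                                       kv_cache_dtype: str,
                                       num_qo_heads: int,
                                       num_kv_heads: int,
                                       head_dim: int) -> bool:
    # The string names the configured KV-cache dtype before any cache tensor
    # exists, so an FP8 cache ("fp8*") is visible at plan time; a torch dtype
    # read from attention metadata may instead reflect the model dtype, which
    # is the kind of mismatch the commit title suggests was being fixed.
    dtype_ok = kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8")

    # Placeholder shape constraints for a TRT-LLM-style decode kernel.
    heads_ok = num_kv_heads > 0 and num_qo_heads % num_kv_heads == 0

    return (num_decodes > 0 and max_seq_len > 0 and dtype_ok
            and heads_ok and head_dim == 128)

Passing the config string keeps the decision independent of when the KV-cache tensor is allocated; a caller would use it exactly as the first hunk does, falling back to decode_wrapper.plan(...) when the gate returns False.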
