@@ -406,8 +406,9 @@ def _plan(self, attn_metadata: FlashInferMetadata):
406406 attn_metadata .decode_wrapper = self ._get_decode_wrapper ()
407407 if not FlashInferBackend .use_trtllm_decode_attention (
408408 self ._num_decodes , attn_metadata .max_seq_len ,
409- attn_metadata .kv_data_type , attn_metadata .num_qo_heads ,
410- attn_metadata .num_kv_heads , attn_metadata .head_dim ):
409+ self .runner .cache_config .cache_dtype ,
410+ attn_metadata .num_qo_heads , attn_metadata .num_kv_heads ,
411+ attn_metadata .head_dim ):
411412 attn_metadata .decode_wrapper .plan (
412413 attn_metadata .paged_kv_indptr [:self ._num_decodes + 1 ],
413414 attn_metadata .paged_kv_indices ,
@@ -594,10 +595,10 @@ def forward(
594595 query: shape = [num_tokens, num_heads, head_size]
595596 key: shape = [num_tokens, num_kv_heads, head_size]
596597 value: shape = [num_tokens, num_kv_heads, head_size]
597- kv_cache: shape -
598+ kv_cache: shape -
598599 # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
599600 # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
600-
601+
601602
602603 attn_metadata: Metadata for attention.
603604 Returns:
0 commit comments