Skip to content

Commit a3bc936

Browse files
committed
[Core/Bugfix] Add query dtype as per FlashInfer API requirements.
1 parent e01c2be commit a3bc936

File tree

2 files changed: +10 −2 lines changed

tests/kernels/test_flashinfer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
                                head_size,
                                block_size,
                                "NONE",
-                               data_type=dtype)
+                               data_type=dtype,
+                               q_data_type=dtype)
     output = wrapper.forward(query,
                              kv_cache_fp8,
                              logits_soft_cap=soft_cap,

vllm/attention/backends/flashinfer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
             query_start_loc=query_start_loc_host,
             device=self.runner.device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=True,
             decode_wrapper=self._graph_decode_wrapper,
             prefill_wrapper=None)
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    # The data type of the query
+    q_data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
     is_profile_run: bool = False

@@ -353,7 +356,10 @@ def begin_forward(self):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+                # kv-cache data type.
+                data_type=self.data_type,
+                # query data type.
+                q_data_type=self.q_data_type)

     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=use_captured_graph,
             is_profile_run=self.is_profile_run)

0 commit comments

Comments (0)