@@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
             query_start_loc=query_start_loc_host,
             device=self.runner.device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=True,
             decode_wrapper=self._graph_decode_wrapper,
             prefill_wrapper=None)
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    # The data type of the query
+    q_data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
     is_profile_run: bool = False
@@ -353,7 +356,10 @@ def begin_forward(self):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+                # kv-cache data type.
+                data_type=self.data_type,
+                # query data type.
+                q_data_type=self.q_data_type)
 
     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=use_captured_graph,
             is_profile_run=self.is_profile_run)
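Why the new field matters: once the paged kv cache can be quantized, its dtype no longer has to match the dtype the model computes queries in, so the begin_forward planning step needs to be told both. Below is a minimal sketch of that mismatch; the tensor shapes, names, and dtypes are hypothetical and purely illustrative, not taken from vLLM.

import torch

# Hypothetical example: the model runs in bf16 while the paged kv cache
# is stored in fp8, so query dtype and kv-cache dtype genuinely differ.
model_dtype = torch.bfloat16           # what model_config.dtype would carry
kv_cache_dtype = torch.float8_e4m3fn   # a quantized paged kv cache

# Queries come out of the model in the model dtype ...
q = torch.randn(4, 8, 128, device="cuda", dtype=model_dtype)

# ... while a kv-cache page is stored in the quantized dtype
# (assumed illustrative layout: page_size x 2 x heads x head_dim).
kv_page = torch.randn(16, 2, 8, 128, device="cuda").to(kv_cache_dtype)

# Before this change, begin_forward() received only data_type (the cache
# dtype) and implicitly assumed queries used the same dtype; passing
# q_data_type makes the query dtype explicit for the mixed-precision case.
assert q.dtype != kv_page.dtype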