     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
+except ImportError:
+    BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
+
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, MultiModalConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
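
Note on the guarded import above: when flashinfer is not installed, every wrapper name is bound to None and the workspace size to 0, so any code path that reaches for the wrappers must check availability first. A minimal sketch of such a guard, assuming the fallback bindings above (the helper name and error message are illustrative, not part of this diff):

# Sketch only: fail fast if the flashinfer backend was selected but the
# optional dependency is missing (the wrapper names are None in that case).
def _ensure_flashinfer_available() -> None:
    if BatchDecodeWithPagedKVCacheWrapper is None:
        raise RuntimeError(
            "attn_backend is 'flashinfer' but the flashinfer package "
            "is not installed")
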
@@ -79,6 +90,11 @@ def __init__(  | 
             return_hidden_states=return_hidden_states,
         )
 
+        self.flashinfer_decode_workspace_buffer = None
+        self.flashinfer_decode_wrapper = None
+        self.flashinfer_prefill_workspace_buffer = None
+        self.flashinfer_prefill_wrapper = None
+
     def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                     num_queries):
         assert isinstance(attn_metadata, FlashAttentionMetadata)
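
The four attributes above are placeholders: the buffers are allocated lazily in execute_model (next hunk), which avoids reserving two 256 MiB workspaces when a different attention backend is in use. A hedged, typed restatement of the same state (the Optional annotations are an assumption for illustration; the diff assigns bare None):

from typing import Optional

import torch

# Illustrative annotations only; these mirror the attributes added above.
flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
flashinfer_decode_wrapper: Optional["BatchDecodeWithPagedKVCacheWrapper"] = None
flashinfer_prefill_workspace_buffer: Optional[torch.Tensor] = None
flashinfer_prefill_wrapper: Optional["BatchPrefillWithPagedKVCacheWrapper"] = None
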
@@ -286,6 +302,37 @@ def execute_model(  | 
                     model_input.prompt_adapter_requests,
                     model_input.prompt_adapter_mapping)
 
+            if self.attn_backend.get_name() == "flashinfer":
+                assert model_input.attn_metadata is not None
+                assert model_input.input_tokens is not None
+                if self.flashinfer_decode_workspace_buffer is None:
+                    self.flashinfer_decode_workspace_buffer = torch.empty(
+                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                        dtype=torch.uint8,
+                        device=self.device)
+                    self.flashinfer_decode_wrapper = \
+                        BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                    self.flashinfer_prefill_workspace_buffer = torch.empty(
+                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                        dtype=torch.uint8,
+                        device=self.device)
+                    self.flashinfer_prefill_wrapper = \
+                        BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+                model_input.attn_metadata.prefill_wrapper = \
+                    self.flashinfer_prefill_wrapper
+                if model_input.attn_metadata.use_cuda_graph:
+                    batch_size = model_input.input_tokens.shape[0]
+                    model_input.attn_metadata.decode_wrapper = \
+                        self.graph_runners[model_input.
+                        virtual_engine][batch_size].flashinfer_decode_wrapper
+                else:
+                    model_input.attn_metadata.decode_wrapper = \
+                        self.flashinfer_decode_wrapper
+                model_input.attn_metadata.begin_forward()
+
         # Detect exec mode
         assert model_input.attn_metadata is not None
         use_cuda_graph = False
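
The hunk above does three things per step when the flashinfer backend is active: it lazily allocates the two 256 MiB uint8 workspace buffers and builds the prefill/decode wrappers on first use, it attaches the prefill wrapper to the step's attention metadata, and it selects the decode wrapper, taking the one captured by the per-batch-size CUDA graph runner when use_cuda_graph is set and the eager wrapper otherwise, before calling begin_forward(). A condensed, hypothetical sketch of the lazy-init half of that pattern (the class and method names are illustrative; only the wrapper constructor and the "NHD" layout string come from the diff):

from typing import Optional

import torch

from flashinfer import BatchDecodeWithPagedKVCacheWrapper

class LazyDecodeWrapper:
    """Allocate the workspace once, build the wrapper once, reuse both."""

    def __init__(self, device: torch.device, buffer_size: int) -> None:
        self._device = device
        self._buffer_size = buffer_size
        self._workspace: Optional[torch.Tensor] = None
        self._wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    def get(self) -> BatchDecodeWithPagedKVCacheWrapper:
        if self._wrapper is None:
            # uint8 scratch buffer; the diff sizes it at 256 * 1024 * 1024.
            self._workspace = torch.empty(self._buffer_size,
                                          dtype=torch.uint8,
                                          device=self._device)
            # "NHD" is the KV-cache layout string used in the diff.
            self._wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._workspace, "NHD")
        return self._wrapper

A graph-capturing path would hold one such wrapper per captured batch size, which is what the graph_runners lookup in the hunk retrieves.
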