@@ -281,10 +281,22 @@ class MultiLevelCascadeAttentionWrapper:
     ...
     >>> outputs[0].shape
     torch.Size([7, 64, 128])
+
+    See Also
+    --------
+    BatchPrefillWithPagedKVCacheWrapper
     """
 
     def __init__(
-        self, num_levels, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+        self,
+        num_levels,
+        float_workspace_buffer: torch.Tensor,
+        kv_layout: str = "NHD",
+        use_cuda_graph: bool = False,
+        qo_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indptr_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_indices_buf_arr: Optional[list[torch.Tensor]] = None,
+        paged_kv_last_page_len_buf_arr: Optional[list[torch.Tensor]] = None,
     ) -> None:
         r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
 
@@ -298,14 +310,59 @@ def __init__(
             buffer should be the same as the device of the input tensors.
         kv_layout : str
             The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        use_cuda_graph : bool
+            Whether to use CUDA graph to capture the kernels. If enabled, the auxiliary
+            data structures will be stored in the provided buffers.
+        qo_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of qo indptr buffers, one per level; the array length should equal
+            the number of levels. The last element of each tensor should be the total
+            number of queries/outputs.
+        paged_kv_indptr_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indptr buffers, one per level; the array length
+            should equal the number of levels.
+        paged_kv_indices_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache indices buffers, one per level; the array length
+            should equal the number of levels.
+        paged_kv_last_page_len_buf_arr : Optional[List[torch.Tensor]]
+            An array of paged kv-cache last-page-length buffers, one per level; the array
+            length should equal the number of levels.
         """
-        self._batch_prefill_wrappers = [
-            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
-            for _ in range(num_levels)
-        ]
+        self._use_cuda_graph = use_cuda_graph
+        if use_cuda_graph:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(
+                    float_workspace_buffer,
+                    kv_layout,
+                    use_cuda_graph=True,
+                    qo_indptr_buf=qo_indptr_buf,
+                    paged_kv_indptr_buf=paged_kv_indptr_buf,
+                    paged_kv_indices_buf=paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf=paged_kv_last_page_len_buf,
+                )
+                for (
+                    qo_indptr_buf,
+                    paged_kv_indptr_buf,
+                    paged_kv_indices_buf,
+                    paged_kv_last_page_len_buf,
+                ) in zip(
+                    qo_indptr_buf_arr,
+                    paged_kv_indptr_buf_arr,
+                    paged_kv_indices_buf_arr,
+                    paged_kv_last_page_len_buf_arr,
+                )
+            ]
+        else:
+            self._batch_prefill_wrappers = [
+                BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+                for _ in range(num_levels)
+            ]
         self._num_levels = num_levels
         self._kv_layout = kv_layout
 
+    @property
+    def is_cuda_graph_enabled(self) -> bool:
+        return self._use_cuda_graph
+
     def reset_workspace_buffer(
         self,
         float_workspace_buffer: torch.Tensor,
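
For context on how the new constructor arguments fit together, here is a minimal sketch of CUDA-graph-enabled construction. It assumes the class is exported as `flashinfer.MultiLevelCascadeAttentionWrapper`; the capacity constants (`max_batch_size`, `max_num_pages`), the `int32` dtype, and the `max_batch_size + 1` indptr length are illustrative assumptions about expected capacity, not requirements stated in this diff:

```python
import torch
import flashinfer

num_levels = 2
max_batch_size = 128   # assumed capacity for the largest graph we plan to capture
max_num_pages = 1024   # assumed upper bound on paged kv-cache pages per level

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# One fixed-size buffer per level, so captured kernels can reuse the same
# memory across replays. Indptr buffers hold max_batch_size + 1 entries.
qo_indptr_bufs = [
    torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
kv_indptr_bufs = [
    torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
kv_indices_bufs = [
    torch.empty(max_num_pages, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]
kv_last_page_len_bufs = [
    torch.empty(max_batch_size, dtype=torch.int32, device="cuda")
    for _ in range(num_levels)
]

wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
    num_levels,
    workspace,
    "NHD",
    use_cuda_graph=True,
    qo_indptr_buf_arr=qo_indptr_bufs,
    paged_kv_indptr_buf_arr=kv_indptr_bufs,
    paged_kv_indices_buf_arr=kv_indices_bufs,
    paged_kv_last_page_len_buf_arr=kv_last_page_len_bufs,
)
assert wrapper.is_cuda_graph_enabled
```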
@@ -912,7 +969,7 @@ def forward(
         k_shared: torch.Tensor,
         v_shared: torch.Tensor,
         unique_kv_cache: torch.Tensor,
-        causal: bool = True,
+        causal: bool = False,
         allow_fp16_qk_reduction: bool = False,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
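
Note that this flips the default masking behavior of `forward`: call sites that relied on the old `causal=True` default must now pass it explicitly. As a plain-PyTorch illustration of what the flag controls (this is not the flashinfer kernel; the single-head layout and shapes are made up):

```python
import torch

# Toy single-head attention over equal-length query/key sequences, showing
# the two masking modes. With causal=True, query i attends only to keys
# j <= i; with causal=False (the new default here), it attends to all keys.
qo_len, head_dim = 4, 8
q = torch.randn(qo_len, head_dim)
k = torch.randn(qo_len, head_dim)
v = torch.randn(qo_len, head_dim)

scores = (q @ k.T) / head_dim**0.5
mask = torch.tril(torch.ones(qo_len, qo_len, dtype=torch.bool))

out_full = torch.softmax(scores, dim=-1) @ v  # causal=False
out_causal = torch.softmax(
    scores.masked_fill(~mask, float("-inf")), dim=-1
) @ v  # causal=True
```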