@@ -177,12 +177,238 @@ def merge_states(v: torch.Tensor, s: torch.Tensor) -> Tuple[torch.Tensor, torch.
     return _kernels.merge_states(v, s)
 
 
+class MultiLevelCascadeAttentionWrapper:
+    r"""Attention wrapper for memory-efficient multi-level cascade inference. This API
+    assumes the KV-Cache of all levels is stored in a unified paged table.
+
+    Check :ref:`our tutorial<page-layout>` for the page table layout, and
+    :ref:`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for the query/output layout.
+
+    The idea of cascade inference is introduced in our `blog post <https://flashinfer.ai/2024/02/02/cascade-inference.html>`_.
+
+    Example
+    -------
+    >>> import torch
+    >>> import flashinfer
+    >>> num_layers = 32
+    >>> num_qo_heads = 64
+    >>> num_kv_heads = 8
+    >>> head_dim = 128
+    >>> page_size = 16
+    >>> # allocate 128MB workspace buffer
+    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> wrapper = flashinfer.MultiLevelCascadeAttentionWrapper(
+    ...     2, workspace_buffer, "NHD"
+    ... )
+    >>> batch_size = 7
+    >>> shared_kv_num_pages = 512
+    >>> unique_kv_num_pages = 128
+    >>> total_num_pages = shared_kv_num_pages + unique_kv_num_pages
+    >>> shared_kv_page_indices = torch.arange(shared_kv_num_pages).int().to("cuda:0")
+    >>> shared_kv_page_indptr = torch.tensor([0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0")
+    >>> unique_kv_page_indices = torch.arange(shared_kv_num_pages, total_num_pages).int().to("cuda:0")
+    >>> unique_kv_page_indptr = torch.tensor(
+    ...     [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
+    ... )
+    >>> shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device="cuda:0")
+    >>> # 1 <= kv_last_page_len <= page_size
+    >>> unique_kv_last_page_len = torch.tensor(
+    ...     [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0"
+    ... )
+    >>> kv_cache_at_layer = [
+    ...     torch.randn(
+    ...         total_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
+    ...     ) for _ in range(num_layers)
+    ... ]
+    >>> qo_indptr_arr = [
+    ...     torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"),  # top level for shared KV-Cache
+    ...     torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0"),  # bottom level for unique KV-Cache
+    ... ]
+    >>> # create auxiliary data structures for cascade attention
+    >>> wrapper.begin_forward(
+    ...     qo_indptr_arr,
+    ...     [shared_kv_page_indptr, unique_kv_page_indptr],
+    ...     [shared_kv_page_indices, unique_kv_page_indices],
+    ...     [shared_kv_last_page_len, unique_kv_last_page_len],
+    ...     num_qo_heads,
+    ...     num_kv_heads,
+    ...     head_dim,
+    ...     page_size,
+    ... )
+    >>> outputs = []
+    >>> for i in range(num_layers):
+    ...     q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0")
+    ...     # compute cascade attention, reusing auxiliary data structures for all layers
+    ...     o = wrapper.forward(q, kv_cache_at_layer[i])
+    ...     outputs.append(o)
+    ...
+    >>> # clear auxiliary data structures
+    >>> wrapper.end_forward()
+    >>> outputs[0].shape
+    torch.Size([7, 64, 128])
+    """
+
+    def __init__(
+        self, num_levels: int, float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD"
+    ) -> None:
+        r"""Constructor of :class:`MultiLevelCascadeAttentionWrapper`.
+
+        Parameters
+        ----------
+        num_levels : int
+            The number of levels in the cascade attention.
+        float_workspace_buffer : torch.Tensor
+            The user reserved float workspace buffer used to store intermediate attention results
+            in the split-k algorithm. The recommended size is 128MB; the device of the workspace
+            buffer should be the same as the device of the input tensors.
+        kv_layout : str
+            The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
+        """
+        self._batch_prefill_wrappers = [
+            BatchPrefillWithPagedKVCacheWrapper(float_workspace_buffer, kv_layout)
+            for _ in range(num_levels)
+        ]
+        self._kv_layout = kv_layout
+
+    def reset_workspace_buffer(
+        self,
+        float_workspace_buffer: torch.Tensor,
+        int_workspace_buffers: list[torch.Tensor],
+    ) -> None:
+        r"""Reset the workspace buffer.
+
+        Parameters
+        ----------
+        float_workspace_buffer : torch.Tensor
+            The new float workspace buffer; its device should be the same as the device of
+            the input tensors.
+
+        int_workspace_buffers : list[torch.Tensor]
+            The new int workspace buffers, one per level; their devices should be the same
+            as the device of the input tensors.
+        """
+        for wrapper, int_workspace_buffer in zip(
+            self._batch_prefill_wrappers, int_workspace_buffers
+        ):
+            wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffer)
+
+    def begin_forward(
+        self,
+        qo_indptr_arr: list[torch.Tensor],
+        paged_kv_indptr_arr: list[torch.Tensor],
+        paged_kv_indices_arr: list[torch.Tensor],
+        paged_kv_last_page_len: list[torch.Tensor],
+        num_qo_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        page_size: int,
+    ):
+        r"""Create auxiliary data structures for multi-level cascade attention for multiple
+        forward calls within the same decode step.
+
+        Parameters
+        ----------
+        qo_indptr_arr : list[torch.Tensor]
+            An array of qo indptr tensors, one per level; the array length should equal the
+            number of levels. Check
+            :ref:`Cascade Inference Query/Output Layout <cascade-qo-indptr-layout>` for the
+            query/output layout. The last element of each tensor should be the total number
+            of queries/outputs.
+        paged_kv_indptr_arr : list[torch.Tensor]
+            An array of paged kv-cache indptr tensors, one per level; the array length should
+            equal the number of levels.
+        paged_kv_indices_arr : list[torch.Tensor]
+            An array of paged kv-cache indices tensors, one per level; the array length should
+            equal the number of levels.
+        paged_kv_last_page_len : list[torch.Tensor]
+            An array of paged kv-cache last page length tensors, one per level; the array
+            length should equal the number of levels.
+        num_qo_heads : int
+            The number of query/output heads.
+        num_kv_heads : int
+            The number of key/value heads.
+        head_dim : int
+            The dimension of the heads.
+        page_size : int
+            The page size of the paged kv-cache.
+        """
+        for (
+            wrapper,
+            qo_indptr,
+            paged_kv_indptr,
+            paged_kv_indices,
+            last_page_len,
+        ) in zip(
+            self._batch_prefill_wrappers,
+            qo_indptr_arr,
+            paged_kv_indptr_arr,
+            paged_kv_indices_arr,
+            paged_kv_last_page_len,
+        ):
+            wrapper.begin_forward(
+                qo_indptr,
+                paged_kv_indptr,
+                paged_kv_indices,
+                last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                page_size,
+            )
+
+    def end_forward(self):
+        r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
+        for wrapper in self._batch_prefill_wrappers:
+            wrapper.end_forward()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        paged_kv_cache: torch.Tensor,
+        **kwargs,
+    ):
+        r"""Compute multi-level cascade attention.
+
+        Parameters
+        ----------
+        q : torch.Tensor
+            The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``.
+        paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
+            The paged KV-Cache stored as a tuple of tensors or a single tensor:
+
+            * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape:
+              ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``,
+              and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``.
+
+            * a single 5-D tensor with shape:
+              ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if
+              :attr:`kv_layout` is ``NHD``, and
+              ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if
+              :attr:`kv_layout` is ``HND``, where ``paged_kv_cache[:, 0]`` is the key-cache and
+              ``paged_kv_cache[:, 1]`` is the value-cache.
+
+        Returns
+        -------
+        torch.Tensor
+            The attention output, with the same shape as :attr:`q`.
+        """
+        out, lse = self._batch_prefill_wrappers[-1].forward_return_lse(
+            q, paged_kv_cache, **kwargs
+        )
+        # NOTE(Zihao): the causal mask should be False for all levels except the last level
+        kwargs["causal"] = False
+        for wrapper in self._batch_prefill_wrappers[:-1]:
+            out_i, lse_i = wrapper.forward_return_lse(q, paged_kv_cache, **kwargs)
+            merge_state_in_place(out, lse, out_i, lse_i)
+
+        return out
+
+
 class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch
-    of requests.
+    of requests. The shared-prefix KV-Cache is stored in standalone tensors, and the
+    unique KV-Cache of each request is stored in a paged KV-Cache data structure.
 
     Check :ref:`our tutorial<page-layout>` for page table layout.
 
+    It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
+    multi-level cascade inference, where the KV-Cache of each level is stored in a unified
+    page table. This API will be deprecated in the future.
+
     Example
     -------
     >>> import torch
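
The new class's `forward` (above) combines per-level partial attention results with `merge_state_in_place`, weighting each level's output by its log-sum-exp (LSE). A minimal pure-PyTorch sketch of that merge rule, with illustrative names (this is a reference formulation, not the library's kernel):

```python
import torch

def merge_states_reference(v_a, s_a, v_b, s_b):
    # v_*: [num_heads, head_dim] partial attention outputs
    # s_*: [num_heads] log-sum-exp (LSE) of the corresponding attention scores
    s_max = torch.maximum(s_a, s_b)
    w_a = torch.exp(s_a - s_max)  # numerically stable renormalization weights
    w_b = torch.exp(s_b - s_max)
    v = (v_a * w_a[:, None] + v_b * w_b[:, None]) / (w_a + w_b)[:, None]
    s = s_max + torch.log(w_a + w_b)  # merged LSE, reusable for further merges
    return v, s
```

Because the merged state carries its own LSE, merges are associative: partial attention over any split of the KV-Cache can be combined pairwise in any order, which is what lets each cascade level be computed independently.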
@@ -328,6 +554,11 @@ def begin_forward(
         The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
         is not equal to ``num_kv_heads``, the function will use
         `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
+
+
+        See Also
+        --------
+        MultiLevelCascadeAttentionWrapper
         """
         self._batch_decode_wrapper.begin_forward(
             unique_kv_indptr,
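
The grouped-query-attention requirement documented above (``num_qo_heads`` a multiple of ``num_kv_heads``) means each KV head is shared by a fixed-size group of query heads. A small illustrative sketch of the head mapping (not the kernel's actual indexing):

```python
num_qo_heads, num_kv_heads = 64, 8
assert num_qo_heads % num_kv_heads == 0, "GQA requires an integer group size"
group_size = num_qo_heads // num_kv_heads  # here: 8 query heads per KV head
kv_head_of = [qo_head // group_size for qo_head in range(num_qo_heads)]
# query heads 0..7 read KV head 0, heads 8..15 read KV head 1, and so on
```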
@@ -433,6 +664,10 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper:
 
     Check :ref:`our tutorial<page-layout>` for paged kv-cache layout.
 
+    It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
+    multi-level cascade inference, where the KV-Cache of each level is stored in a unified
+    page table. This API will be deprecated in the future.
+
     Example
     -------
     >>> import torch
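
Given the deprecation note above, migrating to :class:`MultiLevelCascadeAttentionWrapper` amounts to placing the shared prefix and the per-request suffixes in one unified page table and describing them as two levels. A hedged sketch reusing the numbers from the new class's docstring (all values illustrative):

```python
import torch

batch_size, page_size = 7, 16
shared_kv_num_pages, unique_kv_num_pages = 512, 128
# level 0: a single indptr entry spanning the prefix shared by the whole batch
shared_kv_page_indptr = torch.tensor(
    [0, shared_kv_num_pages], dtype=torch.int32, device="cuda:0"
)
# level 1: per-request unique pages, allocated after the shared region
unique_kv_page_indptr = torch.tensor(
    [0, 17, 29, 44, 48, 66, 100, 128], dtype=torch.int32, device="cuda:0"
)
qo_indptr_arr = [
    torch.tensor([0, batch_size], dtype=torch.int32, device="cuda:0"),  # shared level
    torch.arange(batch_size + 1, dtype=torch.int32, device="cuda:0"),  # unique level
]
```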
@@ -533,7 +768,7 @@ def __init__(
         self._kv_layout = kv_layout
 
     def reset_workspace_buffer(
-        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer
+        self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
     ) -> None:
         r"""Reset the workspace buffer.
 
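
Note that the multi-level wrapper added earlier takes one int workspace buffer per level in its `reset_workspace_buffer`, unlike the single-buffer signature annotated above. A usage sketch; the buffer sizes and the `wrapper` handle (from the class docstring example) are assumptions, not recommendations:

```python
import torch

num_levels = 2
float_workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
int_workspace_buffers = [
    torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
    for _ in range(num_levels)
]
# `wrapper` as constructed in the MultiLevelCascadeAttentionWrapper example above
wrapper.reset_workspace_buffer(float_workspace_buffer, int_workspace_buffers)
```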
@@ -671,6 +906,10 @@ def forward(
         -------
         V : torch.Tensor
             The attention output, shape: ``[qo_indptr[-1], num_heads, head_dim]``.
+
+        See Also
+        --------
+        MultiLevelCascadeAttentionWrapper
         """
         V_shared, S_shared = single_prefill_with_kv_cache_return_lse(
             q,
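
The `forward` body shown above (truncated in this diff) follows the same two-stage cascade pattern as the new wrapper: attend to the shared prefix without a causal mask, attend to each request's unique KV-Cache with the causal mask, then merge the two partial states. A condensed sketch of that flow; the tensor names (`k_shared`, `v_shared`, `unique_kv_cache`) and the inner prefill wrapper handle are illustrative:

```python
# shared prefix: every query attends to the full prefix, so causal=False
V, S = single_prefill_with_kv_cache_return_lse(q, k_shared, v_shared, causal=False)
# per-request suffix: standard causal attention over the paged unique KV-Cache
V_unique, S_unique = batch_prefill_wrapper.forward_return_lse(q, unique_kv_cache, causal=True)
# fold the suffix state into the prefix state; V now holds the final output
merge_state_in_place(V, S, V_unique, S_unique)
```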