
Commit 4bba6fa

feat: expose pytorch api for block sparse attention (#375)
Block sparse attention (for any block size (R, C)) has existed in flashinfer's codebase but was never exposed explicitly in Python. As requested in #367, this PR implements the PyTorch APIs for block sparse attention; according to our experiments, it can greatly accelerate attention computation at low density (10x for Tree Attention in Sequoia).
1 parent b2d5994 commit 4bba6fa
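For orientation, here is a minimal end-to-end sketch of the new wrapper (defined in python/flashinfer/sparse.py below); the sizes, block-sparse pattern, and dtypes are illustrative only, not part of the commit:

import torch
import flashinfer

M, N = 128, 256            # query rows / key-value rows
R, C = 16, 16              # block size (R, C)
MB, NB = M // R, N // C    # number of row / column blocks
num_qo_heads = num_kv_heads = 8
head_dim = 128

# Block-sparse pattern in BSR form: row block i attends to the column blocks
# listed in indices[indptr[i]:indptr[i + 1]].  The pattern below is arbitrary.
block_mask = torch.zeros(MB, NB, dtype=torch.bool)
block_mask[:, 0] = True                                     # e.g. a "global" first block
block_mask[torch.arange(MB), torch.arange(MB) % NB] = True
indptr = torch.cat(
    [torch.zeros(1, dtype=torch.int32), block_mask.sum(-1).cumsum(0).int()]
).cuda()
indices = block_mask.nonzero()[:, 1].int().cuda()

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BlockSparseAttentionWrapper(workspace)
wrapper.begin_forward(indptr, indices, M, N, R, C, num_qo_heads, num_kv_heads, head_dim)

q = torch.randn(M, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_data = torch.randn(N // C, 2, C, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
o = wrapper.forward(q, kv_data)  # (M, num_qo_heads, head_dim)
wrapper.end_forward()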

6 files changed (+419 / -7 lines)

docs/api/python/sparse.rst

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
.. _apisparse:

flashinfer.sparse
=================

Kernels for block sparse flashattention.

.. currentmodule:: flashinfer.sparse

.. autoclass:: BlockSparseAttentionWrapper
    :members:

    .. automethod:: __init__

include/flashinfer/attention/prefill.cuh

Lines changed: 7 additions & 6 deletions

@@ -751,15 +751,16 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, uint32_t* v_smem_o
   *v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_in;
 }

-template <uint32_t num_frags_x, uint32_t num_frags_y>
-__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], float (*d)[2]) {
+template <uint32_t num_frags_x, uint32_t num_frags_y, typename DTypeQKAccum>
+__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], DTypeQKAccum (*m)[2],
+                                            float (*d)[2]) {
   float d_rcp[num_frags_x][2];
   // compute reciprocal of d
 #pragma unroll
   for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
     for (uint32_t j = 0; j < 2; ++j) {
-      d_rcp[fx][j] = math::ptx_rcp(d[fx][j]);
+      d_rcp[fx][j] = (m[fx][j] != DTypeQKAccum(-5e4)) ? math::ptx_rcp(d[fx][j]) : 0.f;
     }
   }

@@ -1161,7 +1162,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC
       o_frag, (float*)smem, m, d, warp_idx, lane_idx);

   // normalize d
-  normalize_d<num_frags_x, num_frags_y>(o_frag, d);
+  normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

   // write back
   write_o_reg_gmem<num_warps_x, num_warps_z, num_frags_x, num_frags_y>(

@@ -1428,7 +1429,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg
       o_frag, (float*)smem, m, d, warp_idx, lane_idx);

   // normalize d
-  normalize_d<num_frags_x, num_frags_y>(o_frag, d);
+  normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

   const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);

@@ -1719,7 +1720,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
       o_frag, (float*)smem, m, d, warp_idx, lane_idx);

   // normalize d
-  normalize_d<num_frags_x, num_frags_y>(o_frag, d);
+  normalize_d<num_frags_x, num_frags_y>(o_frag, m, d);

   const uint32_t num_kv_chunks = ceil_div(kv_len, kv_chunk_size);
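The normalize_d change above threads the per-row running maxima m into the final normalization: if a row's maximum is still the kernel's -5e4 sentinel (a finite stand-in for -inf), every key position that row processed was masked out, so its reciprocal denominator is forced to 0 and the output row comes out as zeros rather than an ill-defined ratio. A toy, single-threaded Python analogue of that intent (shapes and names here are illustrative, not the kernel's data layout):

import torch

SENTINEL = -5e4  # finite stand-in for -inf, matching the comparison in the diff above

def normalize_rows(o_acc: torch.Tensor, m: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # o_acc: (rows, head_dim) un-normalized output accumulator
    # m:     (rows,) running maxima of the attention logits per row
    # d:     (rows,) softmax denominators per row
    # Rows whose maximum never moved off the sentinel saw only masked-out keys;
    # emit zeros for them instead of dividing by a meaningless denominator.
    d_rcp = torch.where(m != SENTINEL, 1.0 / d, torch.zeros_like(d))
    return o_acc * d_rcp.unsqueeze(-1)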

python/flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@
     BatchPrefillWithRaggedKVCacheWrapper,
     BatchPrefillWithPagedKVCacheWrapper,
 )
+from .sparse import BlockSparseAttentionWrapper
 from .cascade import (
     merge_state,
     merge_state_in_place,

python/flashinfer/sampling.py

Lines changed: 1 addition & 1 deletion

@@ -254,7 +254,7 @@ def top_k_top_p_sampling_from_probs(
     >>> samples
     tensor([3, 3, 0, 1], device='cuda:0', dtype=torch.int32)
     >>> success
-    tensor([True, True, True, True], device='cuda:0')
+    tensor([True, True, True, True], device='cuda:0')

     Notes
     -----

python/flashinfer/sparse.py

Lines changed: 292 additions & 0 deletions

@@ -0,0 +1,292 @@
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import Optional
import torch
import logging
from .prefill import _compute_page_qk_indptr
from .quantization import segment_packbits
from .utils import (
    check_pos_encoding_mode,
    check_kv_layout,
    is_float8,
    expand_5d,
    PosEncodingMode,
    TensorLayout,
)

try:
    from . import _kernels
except ImportError as e:
    import os
    import logging

    if os.environ.get("BUILD_DOC", "0") == "1":
        _kernels = None
        logging.warning("Kernels are not loaded in documentation build mode.")
    else:
        raise e


class BlockSparseAttentionWrapper:
    def __init__(
        self,
        workspace_buffer: torch.Tensor,
        kv_layout: str = "NHD",
    ):
        r"""Constructs of :class:`BlockSparseAttentionWrapper`.

        Warning(Zihao): this is an experimental API and subject to change.

        Parameters
        ----------
        workspace_buffer : torch.Tensor
            The user reserved workspace buffer used to store auxiliary data structures,
            recommended size is 128MB, the device of the workspace buffer should be the
            same as the device of the input tensors.

        kv_layout : str
            The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
        """
        check_kv_layout(kv_layout)
        self._kv_layout = kv_layout
        self._workspace_buffer = workspace_buffer
        self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper(
            TensorLayout[kv_layout].value,
            False,  # use_cuda_graph
        )

    def begin_forward(
        self,
        indptr: torch.Tensor,
        indices: torch.Tensor,
        M: int,
        N: int,
        R: int,
        C: int,
        num_qo_heads: int,
        num_kv_heads: int,
        head_dim: int,
        mask: Optional[torch.Tensor] = None,
        packed_mask: Optional[torch.Tensor] = None,
        q_data_type: str = "float16",
    ):
        r"""Create auxiliary data structures for block sparse attention.

        Parameters
        ----------
        indptr : torch.Tensor
            The indptr of the block-sparse matrix, shape (MB + 1,), where MB is the
            number of blocks in the row dimension.
        indices: torch.Tensor
            The indices of the block-sparse matrix, shape (nnz,), where nnz is the
            number of non-zero blocks.
        M : int
            The number of rows of the block-sparse matrix, MB = ceil_div(M, R).
        N : int
            The number of columns of the block-sparse matrix, NB = ceil_div(N, C).
        R : int
            The number of rows in each block.
        C : int
            The number of columns in each block.
        num_qo_heads : int
            The number of heads in the query/output tensor.
        num_kv_heads : int
            The number of heads in the key/value tensor.
        head_dim : int
            The dimension of each head.
        mask : torch.Tensor, optional
            The flattened mask tensor, shape (nnz * R * C,), where nnz is the number
            of non-zero blocks. If every block is full, then we don't need to provide
            the mask tensor.
        packed_mask : torch.Tensor, optional
            The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be
            ignored. The packed mask tensor is generated by
            :func:`flashinfer.quantization.packbits`.
        q_data_type : str, optional
            The data type of the query tensor.

        The :meth:`begin_forward` method should be called before any :meth:`forward` or
        :meth:`forward_return_lse` calls, auxiliary data structures will be created
        during this call and cached for multiple forward calls.

        The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads``
        is not equal to ``num_kv_heads``, the function will use
        `grouped query attention <https://arxiv.org/abs/2305.13245>`_.
        """
        num_rows = len(indptr) - 1
        qo_indptr_host = R * torch.arange(num_rows + 1, dtype=torch.int32)
        qo_indptr_host[-1] = M
        self._qo_indptr = qo_indptr_host.to(indptr.device)
        row_empty = indptr[1:] == indptr[:1]
        if indices.max().item() * C > N:
            raise ValueError("indices out of bound")
        last_block_pos = indices[torch.clamp(indptr[1:], min=1) - 1]
        last_block_pos.masked_fill_(row_empty, 0)
        last_block_len = torch.clamp(N - last_block_pos * C, max=C)

        if mask is not None or packed_mask is not None:
            qk_indptr = _compute_page_qk_indptr(
                self._qo_indptr,
                indptr,  # paged_kv_indptr
                last_block_len,  # paged_kv_last_page_len
                C,  # page_size
            )
        if packed_mask is None and mask is not None:
            # create packed mask from mask
            packed_mask, qk_indptr = segment_packbits(
                mask.contiguous().view(-1), qk_indptr, bitorder="little"
            )

        self._paged_kv_indptr_buf = indptr
        self._paged_kv_indices_buf = indices
        self._paged_kv_last_page_len = last_block_len
        if packed_mask is not None:
            self._packed_mask_buf = packed_mask
            self._qk_indptr_buf = qk_indptr
        else:
            self._packed_mask_buf = None

        empty_q_data = torch.empty(
            0,
            dtype=(
                getattr(torch, q_data_type)
                if isinstance(q_data_type, str)
                else q_data_type
            ),
        )

        self._wrapper.begin_forward(
            self._workspace_buffer,
            self._qo_indptr,
            self._paged_kv_indptr_buf,
            num_rows,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            C,
            empty_q_data,
        )

    def end_forward(self):
        r"""Clear the auxiliary data structures created by :meth:`begin_forward`."""
        self._qo_indptr = None
        self._paged_kv_indptr_buf = None
        self._paged_kv_indices_buf = None
        self._paged_kv_last_page_len = None
        self._packed_mask_buf = None
        self._qk_indptr_buf = None

    def forward(
        self,
        q: torch.Tensor,
        kv_data: torch.Tensor,
        pos_encoding_mode: str = "NONE",
        allow_fp16_qk_reduction: bool = False,
        logits_soft_cap: Optional[float] = None,
        sm_scale: Optional[float] = None,
        rope_scale: Optional[float] = None,
        rope_theta: Optional[float] = None,
    ):
        r"""Compute block-sparse attention between Q/K/V tensors.

        Warning(Zihao): in the next release, kv_data will be decoupled into standalone
        k/v tensors, each with shape (N, num_kv_heads, head_dim).

        Parameters
        ----------
        q : torch.Tensor
            The query tensor, shape (M, num_qo_heads, head_dim).
        kv_data : torch.Tensor
            The key/value tensor, shape (N // C, 2, C, num_kv_heads, head_dim).
        pos_encoding_mode : str, optional
            The position encoding applied inside attention kernels, could be
            ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``.
            Default is ``NONE``.
        allow_fp16_qk_reduction : bool
            Whether to use f16 for qk reduction (faster at the cost of slight precision
            loss).
        logits_soft_cap : Optional[float]
            The attention logits soft capping value (used in Gemini, Grok and Gemma-2,
            etc.), if not provided, will be set to ``0``. If greater than 0, the logits
            will be capped according to formula:
            :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`,
            where :math:`x` is the input logits.
        sm_scale : Optional[float]
            The scale used in softmax, if not provided, will be set to
            ``1.0 / sqrt(head_dim)``.
        rope_scale : Optional[float]
            The scale used in RoPE interpolation, if not provided, will be set to
            ``1.0``.
        rope_theta : Optional[float]
            The theta used in RoPE, if not provided, will be set to ``1e4``.

        Returns
        -------
        torch.Tensor
            The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``.
        """
        check_pos_encoding_mode(pos_encoding_mode)
        if logits_soft_cap is None:
            logits_soft_cap = 0.0
        if sm_scale is None:
            sm_scale = 1.0 / math.sqrt(q.size(-1))
        if rope_scale is None:
            rope_scale = 1.0
        if rope_theta is None:
            rope_theta = 1e4
        if is_float8(q):
            logging.warning(
                "Our current prefill kernel implementation needs f16 input, the f8 inputs "
                " are casted to f16, which could result in performance degradation."
            )
            q = q.to(torch.float16)
            kv_data = kv_data.to(torch.float16)

        kv_data = expand_5d(kv_data, self._kv_layout)

        if self._packed_mask_buf is None:
            return self._wrapper.forward(
                q,
                self._qo_indptr,
                kv_data,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len,
                False,  # causal
                PosEncodingMode[pos_encoding_mode].value,
                allow_fp16_qk_reduction,
                logits_soft_cap,
                sm_scale,
                rope_scale,
                rope_theta,
                False,  # return LSE
            )[0]
        else:
            return self._wrapper.forward_custom_mask(
                q,
                self._qo_indptr,
                kv_data,
                self._paged_kv_indptr_buf,
                self._paged_kv_indices_buf,
                self._paged_kv_last_page_len,
                self._packed_mask_buf,
                self._qk_indptr_buf,
                PosEncodingMode[pos_encoding_mode].value,
                allow_fp16_qk_reduction,
                logits_soft_cap,
                sm_scale,
                rope_scale,
                rope_theta,
                False,  # return LSE
            )[0]
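To connect the docstrings above to concrete inputs, here is a hedged sketch of deriving indptr, indices, and the optional flattened mask from a dense (M, N) boolean attention mask. The helper name is hypothetical, it assumes M % R == 0 and N % C == 0, and the flattened mask follows the per-row-block row-major (qo_len, kv_len) layout implied by the qk_indptr computation above; verify the exact ordering before relying on it:

import torch

def dense_mask_to_bsr(dense_mask: torch.Tensor, R: int, C: int):
    # dense_mask: (M, N) bool, True means "this query row attends to this key".
    M, N = dense_mask.shape
    MB, NB = M // R, N // C
    tiles = dense_mask.view(MB, R, NB, C)        # one (R, C) tile per block
    keep = tiles.any(1).any(-1)                  # (MB, NB) block occupancy
    indptr = torch.cat(
        [torch.zeros(1, dtype=torch.int32), keep.sum(-1).cumsum(0).int()]
    )
    indices = keep.nonzero()[:, 1].int()         # kept column-block ids, row block by row block
    # For row block i, flatten its (R, k_i, C) sub-mask row-major so that each of
    # the R query rows lists its k_i * C kept key positions contiguously.
    mask = torch.cat([tiles[i][:, keep[i], :].reshape(-1) for i in range(MB)])
    return indptr, indices, mask

The returned tensors (moved to the GPU) can be passed as indptr, indices, and mask to begin_forward; omitting mask treats every kept block as fully attended.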
