
Commit 6e487e7

[kernel/fix] Performance Optimization for Decoding Kernel and Benchmarking (#5274)
* prevent re-creating intermediate tensors
* add singleton class holding intermediate values
* fix triton kernel api
* add benchmark in pytest
* fix kernel api and add benchmark
* revise flash decoding triton kernel in/out shapes
* fix calling of triton kernel in modeling
* fix pytest: extract to util functions
1 parent 9e2342b commit 6e487e7

File tree

7 files changed: +382 −152 lines changed


colossalai/inference/modeling/models/llama.py

Lines changed: 10 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
+from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
 from colossalai.logging import get_dist_logger
 
 from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
@@ -209,7 +209,15 @@ def llama_attn_forward(
         if HAS_TRITON:
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+            # TODO Add dummy transpose and squeeze on in/out tensors of the triton flash decoding kernel
+            # in order to maintain compatibility. This part as well as the logics of handling query_states and attn_output
+            # should be revised, as we could see in previous part of `llama_attn_forward` we still have
+            # redundant transpose and the in/out of torch- and triton-version decoding kernel are not consistent.
+            query_states = query_states.transpose(1, 2)
+            attn_output = flash_decoding_attention(
+                query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            )
+            attn_output = attn_output.squeeze(1)
         else:
             attn_output = PagedAttention.pad_decoding_forward(
                 query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask

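For clarity, here is a minimal shape-flow sketch (illustrative only, not part of the commit): it walks the dummy transpose/squeeze compatibility handling around the triton decoding kernel with placeholder tensors, assuming the [bsz, 1, num_heads, head_dim] output layout allocated by the reduce step in flash_decoding.py below.

import torch

# hypothetical sizes for illustration
bsz, num_heads, head_dim = 4, 32, 128

# decoding step: one query token per sequence
query_states = torch.randn(bsz, 1, num_heads, head_dim)  # [bsz, q_len(1), num_heads, head_dim]
query_states = query_states.transpose(1, 2)              # -> [bsz, num_heads, q_len(1), head_dim], the layout the kernel expects

# placeholder standing in for the flash_decoding_attention(...) output
attn_output = torch.randn(bsz, 1, num_heads, head_dim)   # [bsz, 1, num_heads, head_dim]
attn_output = attn_output.squeeze(1)                      # -> [bsz, num_heads, head_dim] for the rest of llama_attn_forward
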
colossalai/kernel/triton/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -9,7 +9,9 @@
 # There may exist import error even if we have triton installed.
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
-    from .flash_decoding import flash_decoding_fwd
+    from .flash_decoding import flash_decoding_attention
+    from .flash_decoding_utils import FDIntermTensors
+
     from .rms_layernorm import rms_layernorm
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
@@ -18,10 +20,11 @@
 
     __all__ = [
         "context_attention_unpadded",
-        "flash_decoding_fwd",
+        "flash_decoding_attention",
         "copy_kv_to_blocked_cache",
         "softmax",
         "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
+        "FDIntermTensors",
     ]

colossalai/kernel/triton/flash_decoding.py

Lines changed: 81 additions & 51 deletions
@@ -9,15 +9,16 @@
 # Triton 2.1.0
 @triton.jit
 def _flash_decoding_fwd_kernel(
-    Q, # [batch_size, head_num, head_dim]
+    Q, # [batch_size, head_num, q_len(1), head_dim]
     KCache, # [num_blocks, num_kv_heads, head_dim, block_size]
     VCache, # [num_blocks, num_kv_heads, head_dim, block_size]
     block_tables, # [batch_size, max_blocks_per_sequence]
     mid_o, # [batch_size, head_num, kv_split_num, head_dim]
     mid_o_lse, # [batch_size, head_num, kv_split_num]
-    context_lengths, # [batch_size]
+    kv_seq_len, # [batch_size]
     stride_qt,
     stride_qh,
+    stride_ql,
     stride_qd,
     stride_cacheb,
     stride_cacheh,
@@ -51,7 +52,7 @@ def _flash_decoding_fwd_kernel(
     tl.static_assert(BLOCK_KV == BLOCK_SIZE)
 
     # get the current (kv) sequence length from provided context lengths tensor
-    cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
 
     offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
     q = tl.load(Q + offsets_q)
@@ -65,7 +66,6 @@ def _flash_decoding_fwd_kernel(
     cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
 
     if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
-        # TODO might want to remove if-else block?
         return
 
     cur_occupied_size = tl.where(
@@ -132,7 +132,7 @@ def _flash_decoding_fwd_reduce_kernel(
     mid_o, # [batch_size, head_num, kv_split_num, head_dim]
     mid_o_lse, # [batch_size, head_num, kv_split_num]
     O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
-    context_lengths,
+    kv_seq_len,
     stride_mid_ot,
     stride_mid_oh,
     stride_mid_ob,
@@ -141,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel(
     stride_o_lseh,
     stride_o_lseb,
     stride_ob,
+    stride_ol,
     stride_oh,
     stride_od,
     BLOCK_KV: tl.constexpr,
@@ -149,7 +150,7 @@ def _flash_decoding_fwd_reduce_kernel(
     cur_seq_idx = tl.program_id(0)
     cur_head_idx = tl.program_id(1)
 
-    cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
     offsets_dmodel = tl.arange(0, HEAD_DIM)
 
     # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
@@ -181,97 +182,126 @@ def _flash_decoding_fwd_reduce_kernel(
 
 # Decoding Stage
 # Used with blocked KV Cache (PagedAttention)
-def flash_decoding_fwd(
-    q: torch.Tensor, # [bsz(e.g.num_tokens), 1, num_heads, head_dim]
-    k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
-    v_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size]
-    context_lengths: torch.Tensor, # [batch_size]
-    block_tables: torch.Tensor, # [batch_size, max_blocks_per_sequence]
+def flash_decoding_attention(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_seq_len: torch.Tensor,
+    block_tables: torch.Tensor,
     block_size: int,
-    num_kv_group: int = 1,
+    max_seq_len_in_batch: int = None,
+    mid_output: torch.Tensor = None,
+    mid_output_lse: torch.Tensor = None,
+    sm_scale: int = None,
+    kv_group_num: int = 1,
 ):
-    bsz, _, num_heads, head_dim = q.shape
+    """
+    Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
+
+    Args:
+        q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
+        k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
+        v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
+        kv_seq_len (torch.Tensor): [batch_size]
+            records the (kv) sequence lengths incorporating past kv sequence lengths.
+        block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
+        max_seq_len_in_batch (int): Maximum sequence length in the batch.
+        mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
+            Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
+        mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
+            Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
+        block_size (int): Size of each block in the blocked key/value cache.
+        num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
+
+    Returns:
+        Output tensor with shape [bsz, num_heads, q_len, head_dim]
+    """
+    bsz, num_heads, _, head_dim = q.shape
 
     assert head_dim in {32, 64, 128, 256}
-    assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
+    assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
         f"Got incompatible batch size (number of seqs):\n"
-        f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
+        f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
         f"batch size {bsz}"
     )
     assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
         f"Got incompatible block size on kv caches:\n"
        f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
         f"v_cache block_size {v_cache.size(-1)}"
     )
-    # NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths.
-    bsz = context_lengths.size(0) # e.g. the number of seqs
-    max_seq_len = context_lengths.max().item()
-    sm_scale = 1.0 / (head_dim**0.5)
 
     # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
     # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`)
     assert block_size in {16, 32, 64, 128}
     BLOCK_KV = block_size
 
-    kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV
-    mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
-    mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-
-    if q.dim() == 4:
-        assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}"
-        q = q.squeeze(1)
+    sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale
+    max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
+    # For compatibility (TODO revise modeling in future)
+    kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
+    mid_output = (
+        torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
+        if mid_output is None
+        else mid_output
+    )
+    mid_output_lse = (
+        torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+        if mid_output_lse is None
+        else mid_output_lse
+    )
 
-    grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV))
+    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
     _flash_decoding_fwd_kernel[grid](
         q,
         k_cache,
         v_cache,
         block_tables,
-        mid_o,
-        mid_o_lse,
-        context_lengths,
+        mid_output,
+        mid_output_lse,
+        kv_seq_len,
         q.stride(0),
         q.stride(1),
         q.stride(2),
+        q.stride(3),
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
         k_cache.stride(3),
         block_tables.stride(0),
         block_tables.stride(1),
-        mid_o.stride(0),
-        mid_o.stride(1),
-        mid_o.stride(2),
-        mid_o.stride(3),
-        mid_o_lse.stride(0),
-        mid_o_lse.stride(1),
-        mid_o_lse.stride(2),
+        mid_output.stride(0),
+        mid_output.stride(1),
+        mid_output.stride(2),
+        mid_output.stride(3),
+        mid_output_lse.stride(0),
+        mid_output_lse.stride(1),
+        mid_output_lse.stride(2),
         sm_scale,
-        KV_GROUPS=num_kv_group,
+        KV_GROUPS=kv_group_num,
         BLOCK_KV=block_size,
         BLOCK_SIZE=block_size,
         HEAD_DIM=head_dim,
     )
 
-    output = torch.zeros_like(q)
-    output = output.view(-1, output.size(-2), output.size(-1))
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped
 
     grid = (bsz, num_heads)
     _flash_decoding_fwd_reduce_kernel[grid](
-        mid_o,
-        mid_o_lse,
+        mid_output,
+        mid_output_lse,
         output,
-        context_lengths,
-        mid_o.stride(0),
-        mid_o.stride(1),
-        mid_o.stride(2),
-        mid_o.stride(3),
-        mid_o_lse.stride(0),
-        mid_o_lse.stride(1),
-        mid_o_lse.stride(2),
+        kv_seq_len,
+        mid_output.stride(0),
+        mid_output.stride(1),
+        mid_output.stride(2),
+        mid_output.stride(3),
+        mid_output_lse.stride(0),
+        mid_output_lse.stride(1),
+        mid_output_lse.stride(2),
         output.stride(0),
         output.stride(1),
         output.stride(2),
+        output.stride(3),
         BLOCK_KV=block_size,
         HEAD_DIM=head_dim,
     )
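
To make the new keyword arguments concrete, here is a hedged usage sketch of flash_decoding_attention (it needs a CUDA device with triton installed; tensor contents are random placeholders, and giving each sequence its own contiguous range of cache blocks is an assumption made purely for illustration):

import torch
from colossalai.kernel.triton import flash_decoding_attention

bsz, num_heads, head_dim = 2, 8, 128
block_size, max_blocks_per_seq = 16, 4
num_blocks = bsz * max_blocks_per_seq
device = "cuda"

q = torch.randn(bsz, num_heads, 1, head_dim, dtype=torch.float16, device=device)
k_cache = torch.randn(num_blocks, num_heads, head_dim, block_size, dtype=torch.float16, device=device)
v_cache = torch.randn(num_blocks, num_heads, head_dim, block_size, dtype=torch.float16, device=device)
kv_seq_len = torch.randint(1, block_size * max_blocks_per_seq, (bsz,), dtype=torch.int32, device=device)
# each sequence owns a private, contiguous range of cache blocks (illustrative layout)
block_tables = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(bsz, max_blocks_per_seq)

# optional pre-allocated intermediate buffers, re-used across decoding steps
max_seq_len_in_batch = kv_seq_len.max().item()
kv_max_split_num = (max_seq_len_in_batch + block_size - 1) // block_size
mid_output = torch.zeros(bsz, num_heads, kv_max_split_num, head_dim, dtype=torch.float32, device=device)
mid_output_lse = torch.zeros(bsz, num_heads, kv_max_split_num, dtype=torch.float32, device=device)

out = flash_decoding_attention(
    q, k_cache, v_cache, kv_seq_len, block_tables, block_size,
    max_seq_len_in_batch=max_seq_len_in_batch,
    mid_output=mid_output,
    mid_output_lse=mid_output_lse,
)  # out holds the attention output for the single decoding token of each sequence
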
colossalai/kernel/triton/flash_decoding_utils.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+import torch
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.utils import get_current_device
+
+
+class FDIntermTensors(metaclass=SingletonMeta):
+    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.
+    For now, it holds intermediate output and logsumexp (which will be used in reduction step along kv)
+    """
+
+    def __init__(self):
+        self._tensors_initialized = False
+
+    @property
+    def is_initialized(self):
+        return self._tensors_initialized
+
+    @property
+    def mid_output(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._mid_output
+
+    @property
+    def mid_output_lse(self):
+        assert self.is_initialized, "Intermediate tensors not initialized yet"
+        return self._mid_output_lse
+
+    def initialize(
+        self,
+        max_batch_size: int,
+        num_attn_heads: int,
+        kv_max_split_num: int,
+        head_dim: int,
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = get_current_device(),
+    ) -> None:
+        """Initialize tensors.
+
+        Args:
+            max_batch_size (int): The maximum batch size over all the model forward.
+                This could be greater than the batch size in attention forward func when using dynamic batch size.
+            num_attn_heads (int)): Number of attention heads.
+            kv_max_split_num (int): The maximum number of blocks splitted on kv in flash-decoding algorithm.
+                **The maximum length/size of blocks splitted on kv should be the kv cache block size.**
+            head_dim (int): Head dimension.
+            dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
+            device (torch.device, optional): Device used to initialize intermediate tensors.
+        """
+        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."
+
+        self._mid_output = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
+        )
+        self._mid_output_lse = torch.empty(
+            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
+        )
+        self._tensors_initialized = True

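A short sketch of how this singleton might be wired together with the decoding kernel (assumptions: a CUDA device is available and the sizes passed to initialize() are illustrative; per the kernel docstring above, buffers sized for a maximum batch size can be passed directly as long as max_bsz >= bsz):

import torch
from colossalai.kernel.triton import FDIntermTensors, flash_decoding_attention

fd_interm = FDIntermTensors()          # SingletonMeta: every call returns the same instance
if not fd_interm.is_initialized:
    fd_interm.initialize(
        max_batch_size=8,              # illustrative upper bound on batch size
        num_attn_heads=32,
        kv_max_split_num=128,          # illustrative upper bound on max kv length // block_size
        head_dim=128,
        dtype=torch.float32,
        device=torch.device("cuda"),
    )

# at every decoding step, re-use the pre-allocated buffers instead of re-creating them, e.g.
# out = flash_decoding_attention(
#     q, k_cache, v_cache, kv_seq_len, block_tables, block_size,
#     mid_output=fd_interm.mid_output,
#     mid_output_lse=fd_interm.mid_output_lse,
# )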