Commit 5b17b01

Author: niushengxiao (committed)

opt: optimize fp8kv performance

1 parent 94c4bb8 commit 5b17b01

File tree: 3 files changed, +66 -68 lines changed

lightllm/common/basemodel/triton_kernel/q_per_head_fp8_quant.py

Lines changed: 13 additions & 17 deletions

```diff
@@ -8,12 +8,10 @@
 def _per_head_max_reduce_kernel(
     Q,
     Scales,
-    BatchIds,
     StartLoc,
     stride_q_t,
     stride_q_h,
     stride_scales_b,
-    SET_BATCH_IDS: tl.constexpr,
     FP8_MAX: tl.constexpr,
     BLOCK_T: tl.constexpr,
     BLOCK_D: tl.constexpr,
@@ -32,8 +30,6 @@ def _per_head_max_reduce_kernel(
         mask = (t_idx[:, None] < end_loc) & (q_range[None, :] < stride_q_h)
         q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
         max_val = tl.maximum(tl.max(q_vals.abs()), max_val)
-        if SET_BATCH_IDS:
-            tl.store(BatchIds + t_idx, b_id, mask=t_idx < end_loc)
 
     scale = tl.where(max_val > 0, max_val / FP8_MAX, 1.0)
     scale_ptr = Scales + b_id * stride_scales_b + h_id
@@ -73,29 +69,29 @@ def _apply_quantization_kernel(
 
 
 @torch.no_grad()
-def q_per_head_fp8_quant(q, seq_lens, b1_start_loc):
+def q_per_head_fp8_quant(q, seq_lens, b1_start_loc, scale_out=None, batch_ids=None):
     T, H, D = q.shape
     B = seq_lens.shape[0]
-    device = q.device
-
-    q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
-    scales = torch.empty((B, H), dtype=torch.float32, device=device)
-    batch_ids = torch.zeros((T,), dtype=torch.int32, device=device)
 
     BLOCK_D = triton.next_power_of_2(D)
     BLOCK_T = 256
     num_warps = 4
     num_stages = 2
+
+    q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
+    if scale_out is None:
+        scale_out = torch.empty((B, H), dtype=torch.float32, device=q.device)
+    if batch_ids is None:
+        batch_ids = torch.repeat_interleave(torch.arange(B, device=q.device), seq_lens)
+
     _per_head_max_reduce_kernel[(B, H)](
         q,
-        scales,
-        batch_ids,
+        scale_out,
         b1_start_loc,
         q.stride(0),
         q.stride(1),
-        scales.stride(0),
+        scale_out.stride(0),
         FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
-        SET_BATCH_IDS=B > 1,
         BLOCK_T=BLOCK_T,
         BLOCK_D=BLOCK_D,
         num_warps=num_warps,
@@ -106,19 +102,19 @@ def q_per_head_fp8_quant(q, seq_lens, b1_start_loc):
         q,
         q_out,
         batch_ids,
-        scales,
+        scale_out,
         q.stride(0),
        q.stride(1),
         q_out.stride(0),
         q_out.stride(1),
-        scales.stride(0),
+        scale_out.stride(0),
         FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
         FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
         BLOCK_D=BLOCK_D,
         num_warps=num_warps,
         num_stages=num_stages,
     )
-    return q_out, scales
+    return q_out, scale_out
 
 
 def ref_q_per_head_fp8_quant(q, seq_lens):
```
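
Note on this file: `q_per_head_fp8_quant` no longer builds the token-to-request map inside the max-reduce kernel; callers may now pass a preallocated `scale_out` buffer and a precomputed `batch_ids` tensor, so the per-layer prefill path avoids reallocating them on every call. As a rough illustration only (not part of the commit; the function name and exact rounding behavior below are assumptions), a pure-PyTorch sketch of what the two Triton kernels compute:

```python
# Hedged sketch (not from the commit): a pure-PyTorch reference of the
# per-(request, head) fp8 quantization, showing how the new optional
# scale_out / batch_ids buffers are reused across calls.
import torch


def ref_per_head_fp8_quant(q, seq_lens, scale_out=None, batch_ids=None):
    # q: (T, H, D) concatenated tokens, seq_lens: (B,) tokens per request.
    T, H, D = q.shape
    B = seq_lens.shape[0]
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    if batch_ids is None:
        # Token -> request mapping, previously produced by the removed
        # SET_BATCH_IDS branch inside the max-reduce kernel.
        batch_ids = torch.repeat_interleave(torch.arange(B, device=q.device), seq_lens)
    if scale_out is None:
        scale_out = torch.empty((B, H), dtype=torch.float32, device=q.device)
    start = 0
    for b in range(B):
        end = start + int(seq_lens[b])
        amax = q[start:end].abs().amax(dim=(0, 2)).float()  # per-head abs max
        scale_out[b] = torch.where(amax > 0, amax / fp8_max, torch.ones_like(amax))
        start = end
    # Apply each request's per-head scale to every token of that request.
    q_scaled = q.float() / scale_out[batch_ids].unsqueeze(-1)
    q_out = q_scaled.clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q_out, scale_out
```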

lightllm/models/llama/flashattention_infer_struct.py

Lines changed: 37 additions & 3 deletions

```diff
@@ -28,8 +28,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         if self.is_prefill:
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
             self.cu_seqlens_k = self.b1_cu_kv_seq_len.int()
-            self.page_table = torch.empty((self.batch_size, self.max_seq_len), dtype=torch.int32).to(input_ids.device)
+            self.page_table = torch.empty(
+                (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device
+            )
             self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+            if "calibration_fp8kv" in model.mode:
+                device = input_ids.device
+                self.q_scale = torch.empty(
+                    (self.batch_size, self.mem_manager.head_num), dtype=torch.float32, device=device
+                )
+                self.batch_ids = torch.repeat_interleave(torch.arange(self.batch_size, device=device), self.b_q_seq_len)
         else:
             # Meta information of flashattention for decoding
             self.cu_seqlens_q = self.b1_cu_q_seq_len.int()
@@ -43,12 +51,38 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                     : self.batch_size * model.graph_max_len_in_batch
                 ].reshape(self.batch_size, model.graph_max_len_in_batch)
             else:
-                self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
-                    input_ids.device
+                self.page_table = torch.empty(
+                    (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device
                 )
 
             self.page_table[:, :max_seq_len_k].copy_(
                 model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
             )
             self.page_table[:, max_seq_len_k:].fill_(0)
+
+        if "calibration_fp8kv" in model.mode:
+            offline_scales = self.mem_manager.offline_fp8_quant_manager.scales
+            head_num = self.mem_manager.head_num
+            self.k_descale = (
+                offline_scales[:, :head_num]
+                .view(-1, 1, head_num)
+                .expand(offline_scales.shape[0], self.batch_size, head_num)
+                if offline_scales is not None
+                else torch.ones(
+                    (self.mem_manager.layer_num, self.batch_size, head_num),
+                    dtype=torch.float32,
+                    device=input_ids.device,
+                )
+            )
+            self.v_descale = (
+                offline_scales[:, head_num:]
+                .view(-1, 1, head_num)
+                .expand(offline_scales.shape[0], self.batch_size, head_num)
+                if offline_scales is not None
+                else torch.ones(
+                    (self.mem_manager.layer_num, self.batch_size, head_num),
+                    dtype=torch.float32,
+                    device=input_ids.device,
+                )
+            )
         return
```
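
The `k_descale`/`v_descale` tensors above are now built once per forward pass in the infer struct instead of once per layer. A minimal sketch of the shape handling, assuming the `(layer_num, 2 * head_num)` layout of `offline_fp8_quant_manager.scales` that the slicing above implies (the sizes below are made up for illustration):

```python
# Hedged sketch with assumed sizes: how the offline scale table splits into
# per-layer K/V descales that each attention layer later indexes by layer id.
import torch

layer_num, head_num, batch_size = 32, 8, 4
# Assumed layout, inferred from the slicing above: the first head_num columns
# hold K scales, the remaining head_num columns hold V scales.
offline_scales = torch.rand(layer_num, 2 * head_num, dtype=torch.float32)

k_descale = offline_scales[:, :head_num].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)
v_descale = offline_scales[:, head_num:].view(-1, 1, head_num).expand(layer_num, batch_size, head_num)

# Each layer then reads a (batch_size, head_num) slice with plain indexing:
assert k_descale[0].shape == (batch_size, head_num)
assert v_descale[layer_num - 1].shape == (batch_size, head_num)
```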

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 16 additions & 48 deletions

```diff
@@ -132,7 +132,7 @@ def _bind_attention(self):
                     LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self
                 )
             else:
-                raise Exception("fp8 kvcache only support fa3 and flashinfer backend")
+                raise Exception("calibration fp8 kvcache only support fa3 and flashinfer backend")
         elif "triton_flashdecoding" in self.mode:
             self._token_attention_kernel = partial(
                 LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self
@@ -333,6 +333,13 @@ def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionSt
     def _context_attention_flashattention_fp8(
         self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
+        q, q_scale = q_per_head_fp8_quant(
+            q.view(q.shape[0], self.tp_k_head_num_, -1),
+            infer_state.b_seq_len,
+            infer_state.cu_seqlens_q,
+            infer_state.q_scale,
+            infer_state.batch_ids,
+        )
         cache_k = (
             (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :])
             .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_)
@@ -347,43 +354,21 @@
             .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
             .view(torch.float8_e4m3fn)
         )
-        q, q_scale = q_per_head_fp8_quant(
-            q.view(q.shape[0], self.tp_k_head_num_, -1),
-            infer_state.b_seq_len,
-            infer_state.cu_seqlens_q,
-        )
-        q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        q_descale = q_scale
-        ones_scales = torch.ones((infer_state.batch_size, self.tp_k_head_num_), device=q.device, dtype=torch.float32)
-        offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
-        k_descale = (
-            offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        v_descale = (
-            offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.view(-1, self.tp_q_head_num_, self.head_dim_),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
             cache_seqlens=infer_state.b_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=infer_state.q_max_seq_len,
-            softmax_scale=sm_scale,
             causal=True,
             window_size=(-1, -1),
             softcap=0.0,
-            q_descale=q_descale,
-            k_descale=k_descale,
-            v_descale=v_descale,
+            q_descale=q_scale,
+            k_descale=infer_state.k_descale[self.layer_num_],
+            v_descale=infer_state.v_descale[self.layer_num_],
             return_softmax_lse=False,
         )
         return o
@@ -867,38 +852,21 @@ def _token_decode_attention_flashattention_fp8(
             .view(torch.float8_e4m3fn)
         )
         q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True)
-        q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        q_descale = q_scale.view(q.shape[0], self.tp_k_head_num_)
-        ones_scales = torch.ones((infer_state.batch_size, self.tp_k_head_num_), device=q.device, dtype=torch.float32)
-        offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
-        k_descale = (
-            offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        v_descale = (
-            offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.view(-1, self.tp_q_head_num_, self.head_dim_),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
             cache_seqlens=infer_state.b_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=1,
-            softmax_scale=sm_scale,
             causal=False,
             window_size=(-1, -1),
             softcap=0.0,
-            q_descale=q_descale,
-            k_descale=k_descale,
-            v_descale=v_descale,
+            q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_),
+            k_descale=infer_state.k_descale[self.layer_num_],
+            v_descale=infer_state.v_descale[self.layer_num_],
             return_softmax_lse=False,
         )
         return o
```
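
Two details keep this refactor behavior-preserving: each layer now only indexes the `k_descale`/`v_descale` tensors precomputed in `flashattention_infer_struct.py`, and the explicit `softmax_scale` argument is dropped because `flash_attn_with_kvcache` falls back to `1 / sqrt(head_dim)` when no scale is passed, which matches the removed `sm_scale = 1.0 / (Lq ** 0.5)` (with `Lq` equal to the head dimension). A trivial sketch of that equivalence, with an assumed head size:

```python
# Hedged sketch: the removed explicit softmax scale equals flash-attn's
# documented default of 1/sqrt(head_dim), so omitting the argument is safe.
head_dim = 128  # assumed value for illustration
sm_scale_removed = 1.0 / (head_dim ** 0.5)  # what the old code computed (Lq == head_dim)
sm_scale_default = head_dim ** -0.5         # flash_attn_with_kvcache default when softmax_scale is None
assert abs(sm_scale_removed - sm_scale_default) < 1e-12
```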
