@@ -132,7 +132,7 @@ def _bind_attention(self):
                     LlamaTransformerLayerInfer._token_decode_attention_flashinfer_fp8, self
                 )
             else:
-                raise Exception("fp8 kvcache only support fa3 and flashinfer backend")
+                raise Exception("calibration fp8 kvcache only support fa3 and flashinfer backend")
         elif "triton_flashdecoding" in self.mode:
             self._token_attention_kernel = partial(
                 LlamaTransformerLayerInfer._token_decode_attention_flashdecoding, self
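For context, the dispatch above chooses an attention kernel per entry in self.mode and binds the unbound class method to the instance with functools.partial, so the hot path later calls self._token_attention_kernel(...) without re-checking the mode. A minimal sketch of that pattern, with an illustrative class name, mode string, and kernel body rather than the real ones:

from functools import partial


class _Infer:
    def __init__(self, mode):
        self.mode = mode
        self._bind_attention()

    def _bind_attention(self):
        # Bind the chosen kernel once at init; the else branch mirrors the
        # raise in the hunk above when no supported backend matches.
        if "flashdecoding" in self.mode:
            self._token_attention_kernel = partial(_Infer._decode_flashdecoding, self)
        else:
            raise Exception("unsupported backend")

    def _decode_flashdecoding(self, q):
        return f"flashdecoding({q})"


print(_Infer(["flashdecoding"])._token_attention_kernel("q"))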
@@ -333,6 +333,13 @@ def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionSt
     def _context_attention_flashattention_fp8(
         self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None
     ):
+        q, q_scale = q_per_head_fp8_quant(
+            q.view(q.shape[0], self.tp_k_head_num_, -1),
+            infer_state.b_seq_len,
+            infer_state.cu_seqlens_q,
+            infer_state.q_scale,
+            infer_state.batch_ids,
+        )
         cache_k = (
             (infer_state.mem_manager.kv_buffer[self.layer_num_][:, : self.tp_k_head_num_, :])
             .reshape(-1, 1, self.tp_k_head_num_, self.head_dim_)
@@ -347,43 +354,21 @@ def _context_attention_flashattention_fp8(
             .reshape(-1, 1, self.tp_v_head_num_, self.head_dim_)
             .view(torch.float8_e4m3fn)
         )
-        q, q_scale = q_per_head_fp8_quant(
-            q.view(q.shape[0], self.tp_k_head_num_, -1),
-            infer_state.b_seq_len,
-            infer_state.cu_seqlens_q,
-        )
-        q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        q_descale = q_scale
-        ones_scales = torch.ones((infer_state.batch_size, self.tp_k_head_num_), device=q.device, dtype=torch.float32)
-        offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
-        k_descale = (
-            offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        v_descale = (
-            offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.view(-1, self.tp_q_head_num_, self.head_dim_),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
             cache_seqlens=infer_state.b_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=infer_state.q_max_seq_len,
-            softmax_scale=sm_scale,
             causal=True,
             window_size=(-1, -1),
             softcap=0.0,
-            q_descale=q_descale,
-            k_descale=k_descale,
-            v_descale=v_descale,
+            q_descale=q_scale,
+            k_descale=infer_state.k_descale[self.layer_num_],
+            v_descale=infer_state.v_descale[self.layer_num_],
             return_softmax_lse=False,
         )
         return o
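Two behavior-preserving details in this hunk. First, the deleted lines computed sm_scale = 1.0 / sqrt(head_dim) (q.shape[-1] after the view to (-1, tp_q_head_num_, head_dim_)), and flash-attn documents softmax_scale as defaulting to 1/sqrt(headdim) when it is not passed, so dropping the argument should leave the numerics unchanged. Second, the per-call expansion of offline_fp8_quant_manager.scales into (batch_size, tp_k_head_num_) descale tensors gives way to the precomputed per-layer infer_state.k_descale / infer_state.v_descale. A small sketch of the first point, with an illustrative head_dim:

import math

import torch

head_dim = 128  # illustrative; the real value is self.head_dim_
q = torch.randn(4, 8, head_dim)  # stand-in for q viewed to (-1, heads, head_dim)

# What the removed lines computed:
sm_scale = 1.0 / (q.shape[-1] ** 0.5)

# flash_attn_with_kvcache falls back to 1/sqrt(headdim) when softmax_scale is
# omitted (per the flash-attn docstring), so the value is identical:
assert math.isclose(sm_scale, q.shape[-1] ** (-0.5))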
@@ -867,38 +852,21 @@ def _token_decode_attention_flashattention_fp8(
             .view(torch.float8_e4m3fn)
         )
         q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * self.tp_k_head_num_, -1), use_per_token_if_dynamic=True)
-        q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
-        q_descale = q_scale.view(q.shape[0], self.tp_k_head_num_)
-        ones_scales = torch.ones((infer_state.batch_size, self.tp_k_head_num_), device=q.device, dtype=torch.float32)
-        offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales
-        k_descale = (
-            offline_scales[self.layer_num_][: self.tp_k_head_num_].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        v_descale = (
-            offline_scales[self.layer_num_][self.tp_k_head_num_ :].expand(infer_state.batch_size, self.tp_k_head_num_)
-            if offline_scales is not None
-            else ones_scales
-        )
-        Lq = q.shape[-1]
-        sm_scale = 1.0 / (Lq ** 0.5)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.view(-1, self.tp_q_head_num_, self.head_dim_),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=infer_state.page_table,
             cache_seqlens=infer_state.b_seq_len,
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=1,
-            softmax_scale=sm_scale,
             causal=False,
             window_size=(-1, -1),
             softcap=0.0,
-            q_descale=q_descale,
-            k_descale=k_descale,
-            v_descale=v_descale,
+            q_descale=q_scale.view(infer_state.batch_size, self.tp_k_head_num_),
+            k_descale=infer_state.k_descale[self.layer_num_],
+            v_descale=infer_state.v_descale[self.layer_num_],
             return_softmax_lse=False,
         )
         return o
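The decode path keeps the per-token dynamic quantization of q via scaled_fp8_quant(..., use_per_token_if_dynamic=True) and now views the returned scale straight into the (batch_size, tp_k_head_num_) layout expected by q_descale, instead of materializing a separate q_descale variable. A rough torch-only sketch of what per-token FP8 quantization produces; the real scaled_fp8_quant kernel may differ in clamping and edge cases, and the sizes below are illustrative:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def per_token_fp8_quant(x: torch.Tensor):
    # One scale per row, as with use_per_token_if_dynamic=True.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
    scale = amax / FP8_MAX
    x_fp8 = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale  # scale is the de-scale factor later passed as q_descale


batch_size, tp_k_head_num, group_dim = 2, 4, 256  # illustrative sizes
q = torch.randn(batch_size * tp_k_head_num, group_dim)
q_fp8, q_scale = per_token_fp8_quant(q)
# Viewed as (batch_size, tp_k_head_num) before being passed as q_descale,
# matching the view in the hunk above.
q_descale = q_scale.view(batch_size, tp_k_head_num)
print(q_fp8.dtype, q_descale.shape)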