@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
-
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
@@ -137,6 +136,7 @@ def __init__(
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
     def from_native_module(
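The hunk above registers the per-rank slice of ALiBi slopes as an nn.Parameter, so the slopes follow the module through device and dtype moves and can be handed to the attention kernel later in this diff. The sketch below illustrates the standard ALiBi slope recipe that get_alibi_slopes presumably follows, plus the slice-and-register step; the helper name alibi_slopes_sketch and the head/rank numbers are illustrative assumptions, not the repository's actual implementation.

```python
import math

import torch
import torch.nn as nn


def alibi_slopes_sketch(num_heads: int, device: torch.device) -> torch.Tensor:
    """Standard ALiBi slope recipe (a sketch; the repo's get_alibi_slopes may differ).

    For a power-of-two head count n, slope_i = 2 ** (-8 * i / n) for i = 1..n.
    For other head counts, extra slopes are taken from the next power of two's sequence.
    """
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = base ** torch.arange(1, closest_pow2 + 1, device=device, dtype=torch.float32)
    if closest_pow2 != num_heads:
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        num_extra = num_heads - closest_pow2
        extra = extra_base ** torch.arange(1, 2 * num_extra + 1, 2, device=device, dtype=torch.float32)
        slopes = torch.cat([slopes, extra])
    return slopes


# Mirroring the hunk above: slice out this rank's heads and register them as a
# parameter so they move with the module (values below are hypothetical).
num_attention_heads, num_heads, rank = 40, 10, 0  # e.g. Baichuan-13B sharded over 4 ranks
slopes_start = rank * num_heads
alibi_slopes = alibi_slopes_sketch(num_attention_heads, device=torch.device("cpu"))[
    slopes_start : slopes_start + num_heads
].contiguous()
alibi_slopes = nn.Parameter(alibi_slopes)
```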
@@ -268,19 +268,13 @@ def forward(
         block_size = k_cache.size(-2)

         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
-
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ def forward(
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
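Taken together, the last two hunks let ALiBi models (Baichuan-13B) take the FlashAttention prefill path: the rotary-embedding step is skipped when use_alibi_attn is set, and the per-head slopes are passed through flash_attn_varlen_func's alibi_slopes argument, which flash-attn supports natively since 2.3. Below is a minimal, hedged sketch of that call with packed variable-length sequences; shapes and values are made up, and it assumes a CUDA device with flash-attn >= 2.3 installed.

```python
# Sketch of the prefill path above: packed (varlen) Q/K/V plus per-head ALiBi
# slopes handed straight to FlashAttention. All shapes/values are illustrative.
import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 10, 128
seq_lens = [5, 7]                      # two prompts packed into one batch
token_nums = sum(seq_lens)

q = torch.randn(token_nums, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")  # prefix sums of seq_lens
max_seqlen = max(seq_lens)
alibi_slopes = torch.rand(num_heads, dtype=torch.float32, device="cuda")  # stands in for self.alibi_slopes

attn_output = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    softmax_scale=head_dim ** -0.5,
    causal=True,
    alibi_slopes=alibi_slopes,         # per-head slopes; flash-attn adds the ALiBi bias internally
)
attn_output = attn_output.view(token_nums, -1)  # (token_nums, num_heads * head_dim)
```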