@@ -212,7 +212,6 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.use_qk_norm = config.use_qk_norm
         self.use_conv2d = False
 
-        assert not self.use_qk_norm, "QK norm not supported in static attention yet"
         self.wqs = nn.ModuleList(
             [
                 nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
@@ -241,6 +240,13 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         self.rope = _Rope(rope.params.use_hf_rope)
 
+        if self.use_qk_norm:
+            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+            self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+        else:
+            self.q_norm = torch.nn.Identity()
+            self.k_norm = torch.nn.Identity()
+
     def forward(
         self,
         x: torch.Tensor,
@@ -275,6 +281,10 @@ def from_conv2ds(ts):
         new_ks = from_conv2ds(new_ks)
         new_vs = from_conv2ds(new_vs)
 
+        if self.use_qk_norm:
+            new_qs = [self.q_norm(q) for q in new_qs]
+            new_ks = [self.k_norm(k) for k in new_ks]
+
         new_qs = [self.rope(q, freqs_cos, freqs_sin) for q in new_qs]
         new_ks = [self.rope(k, freqs_cos, freqs_sin) for k in new_ks]
         all_ks = []
@@ -325,6 +335,13 @@ def load_weights_from_attention_mha(self, other: AttentionMHA):
 
         self.wo.weight.data.copy_(other.wo.weight)
 
+        if other.use_qk_norm:
+            self.use_qk_norm = True
+            self.q_norm = torch.nn.RMSNorm(other.q_norm_fn.dim, other.q_norm_fn.eps)
+            self.q_norm.load_state_dict(other.q_norm_fn.state_dict())
+            self.k_norm = torch.nn.RMSNorm(other.k_norm_fn.dim, other.k_norm_fn.eps)
+            self.k_norm.load_state_dict(other.k_norm_fn.state_dict())
+
     def linear_to_conv2d(self):
         def transfer_weight(linear, conv2d):
             conv2d.weight.data.copy_(linear.weight[:, :, None, None])
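For illustration, a minimal standalone sketch of the ordering this diff establishes: per-head RMSNorm on queries and keys first, RoPE second. The tensor shapes, dimensions, and the placeholder rope function below are assumptions made for the example, not the actual module.

# Standalone sketch (assumed shapes, placeholder rope) showing QK norm applied
# per head before rotary embeddings, matching the order in the diff above.
import torch

head_dim, n_heads, seq_len, eps = 64, 4, 8, 1e-5

q_norm = torch.nn.RMSNorm(head_dim, eps)  # would be nn.Identity() when use_qk_norm is False
k_norm = torch.nn.RMSNorm(head_dim, eps)

new_qs = [torch.randn(1, seq_len, head_dim) for _ in range(n_heads)]
new_ks = [torch.randn(1, seq_len, head_dim) for _ in range(n_heads)]

# 1) normalize each head's queries/keys
new_qs = [q_norm(q) for q in new_qs]
new_ks = [k_norm(k) for k in new_ks]

# 2) then apply rotary embeddings (placeholder here; the module uses its _Rope helper)
def rope(x, freqs_cos, freqs_sin):
    return x

new_qs = [rope(q, None, None) for q in new_qs]
new_ks = [rope(k, None, None) for k in new_ks]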