@@ -276,20 +276,29 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
        return None

    # Generate LlamaModel state_dict
-    def permute_original(w, n_heads=llama_model.params.n_heads, dim1=llama_model.params.dim, dim2=llama_model.params.dim):
-        return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-
    hf_state_dict = {}

+    # Sometimes we have repeated key values for the heads
+    dim = llama_model.params.dim
+    num_key_value_heads = llama_model.params.n_kv_heads
+    n_rep = llama_model.params.n_heads // num_key_value_heads
+    key_value_dim = dim // n_rep
+
+    # HuggingFace needs the weights permuted.
+    # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
+    def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
+        return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
    # Transfer weights from llama model to the HF state dictionary format
    hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
    hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)

+    # Add each layer's weights to the HF state dictionary
    for i, layer in enumerate(llama_model.layers):
-        layer_id = layer.layer_id # Assuming llama.c layers have layer_id
+        layer_id = layer.layer_id
        hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
        hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
-        hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone()).to(dtype)
+        hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
        hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
        hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
        hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
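For context (not part of the commit): with grouped-query attention the key/value projections have fewer heads than the query projection, which is why k_proj now passes num_key_value_heads and key_value_dim into permute_original instead of relying on the full-dim defaults. A minimal standalone sketch with made-up toy dimensions:

import torch

# Toy dimensions chosen for illustration only; real models use llama_model.params
dim, n_heads, n_kv_heads = 8, 2, 1            # GQA: 2 query heads share 1 KV head
n_rep = n_heads // n_kv_heads                 # how many query heads per KV head
key_value_dim = dim // n_rep                  # wk/wv have fewer output rows than wq

def permute_original(w, n_heads, dim1, dim2):
    # Reorder rows from llama2.c's interleaved RoPE pairs to HF's split-halves layout
    return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

wq = torch.randn(dim, dim)                    # query projection keeps the full dim x dim shape
wk = torch.randn(key_value_dim, dim)          # key projection is smaller under GQA

q_hf = permute_original(wq, n_heads, dim, dim)
k_hf = permute_original(wk, n_kv_heads, key_value_dim, dim)
print(q_hf.shape, k_hf.shape)                 # torch.Size([8, 8]) torch.Size([4, 8])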
@@ -318,8 +327,9 @@ def permute_original(w, n_heads=llama_model.params.n_heads, dim1=llama_model.par
    max_position_embeddings = llama_model.params.max_seq_len
    rms_norm_eps = llama_model.params.norm_eps

-    # TODO values for: pretraining_tp, initializer_range, use_cache,
-    # tie_word_embeddings, rope_theta, and rope_scaling.
+    # TODO check values for:
+    # pretraining_tp, initializer_range, use_cache,
+    # rope_theta, and rope_scaling.

    config = LlamaConfig(
        vocab_size=vocab_size,