Commit 2dedad6

Added support for repeated kv weights
1 parent d3c25b1

File tree

1 file changed: +17 -7 lines changed

export.py

Lines changed: 17 additions & 7 deletions
@@ -276,20 +276,29 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
         return None
 
     # Generate LlamaModel state_dict
-    def permute_original(w, n_heads=llama_model.params.n_heads, dim1=llama_model.params.dim, dim2=llama_model.params.dim):
-        return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
-
     hf_state_dict = {}
 
+    # Sometimes we have repeated key values for the heads
+    dim = llama_model.params.dim
+    num_key_value_heads = llama_model.params.n_kv_heads
+    n_rep = llama_model.params.n_heads // num_key_value_heads
+    key_value_dim = dim // n_rep
+
+    # HuggingFace needs the weights permuted.
+    # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
+    def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
+        return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
+
     # Transfer weights from llama model to the HF state dictionary format
     hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
     hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)
 
+    # Add each layer's weights to the HF state dictionary
     for i, layer in enumerate(llama_model.layers):
-        layer_id = layer.layer_id # Assuming llama.c layers have layer_id
+        layer_id = layer.layer_id
         hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
         hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
-        hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone()).to(dtype)
+        hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
         hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
         hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
         hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
@@ -318,8 +327,9 @@ def permute_original(w, n_heads=llama_model.params.n_heads, dim1=llama_model.par
     max_position_embeddings = llama_model.params.max_seq_len
     rms_norm_eps = llama_model.params.norm_eps
 
-    # TODO values for: pretraining_tp, initializer_range, use_cache,
-    # tie_word_embeddings, rope_theta, and rope_scaling.
+    # TODO check values for:
+    # pretraining_tp, initializer_range, use_cache,
+    # rope_theta, and rope_scaling.
 
     config = LlamaConfig(
         vocab_size=vocab_size,
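
The only functional change is that k_proj is now permuted with the key/value head count and the reduced key_value_dim instead of the full attention dimensions. Below is a small, self-contained sketch of that dimension bookkeeping; the width and head counts (dim=288, 6 query heads, 3 kv heads) are made-up illustration values, not taken from the commit, and permute_original mirrors the helper in the diff with the llama_model defaults dropped.

import torch

# Made-up GQA-style configuration for illustration only.
dim = 288
n_heads = 6
n_kv_heads = 3

n_rep = n_heads // n_kv_heads         # 2 query heads share each kv head
key_value_dim = dim // n_rep          # 144: row count of wk/wv when kv heads are shared

def permute_original(w, heads, dim1, dim2):
    # Same interleave/split permutation used in the diff above.
    return w.view(dim1, dim2).reshape(heads, dim1 // heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

wq = torch.randn(dim, dim)            # q_proj keeps the full dim x dim shape
wk = torch.randn(key_value_dim, dim)  # k_proj shrinks to (dim // n_rep) x dim

q_hf = permute_original(wq, n_heads, dim, dim)
k_hf = permute_original(wk, n_kv_heads, key_value_dim, dim)

print(q_hf.shape, k_hf.shape)         # torch.Size([288, 288]) torch.Size([144, 288])

Passing the full dim as both dim1 and dim2 for a (144, 288) wk, as the old code effectively did, would fail at the view step, which is why the extra arguments were added.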
