From 462bf3c0604f8328c0b3c2f0c686724d295957e0 Mon Sep 17 00:00:00 2001 From: Pablo Date: Fri, 5 Jul 2024 13:19:56 +0000 Subject: [PATCH] add conversion of cross-attention layers --- .../mllama/convert_mllama_weights_to_hf.py | 282 ++++++++++++------ 1 file changed, 191 insertions(+), 91 deletions(-) diff --git a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py index cca9a0bec908fb..4bfe1005959cd1 100644 --- a/src/transformers/models/mllama/convert_mllama_weights_to_hf.py +++ b/src/transformers/models/mllama/convert_mllama_weights_to_hf.py @@ -21,7 +21,7 @@ import torch -from transformers import LlamaConfig, LlamaForCausalLM, CLIPVisionConfig, LlamaTokenizer, PreTrainedTokenizerFast +from transformers import LlamaConfig, LlamaForCausalLM, MllamaConfig, CLIPVisionConfig, LlamaTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import TikTokenConverter @@ -83,7 +83,10 @@ def write_model( params = read_json(os.path.join(input_base_path, "params.json")) num_shards = NUM_SHARDS[model_size] params = params.get("model", params) - n_layers = params["n_layers"] + n_layers = params["n_layers"] # language model self-attention layers + n_layers_cross_attention = 20 # language model cross-attention layers + n_layers_vision_transformer = 32 # vision model 1st transformer layers + n_layers_global_transformer = 8 # global transformer vision layers n_heads = params["n_heads"] n_heads_per_shard = n_heads // num_shards dim = params["dim"] @@ -127,89 +130,77 @@ def permute(w, n_heads, dim1=dim, dim2=dim): torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") for i in range(num_shards) ] + + + # weights are loaded. Now, we convert them into a state_dict. + + # first, language model weights. + + # we start with self-attention layers. 
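+    # for reference, a rough sketch of the renaming convention applied below (original
+    # checkpoint key on the left, HF key on the right); the code itself is the exhaustive list:
+    #   text_model.layers.{i}.attention.wq.weight    -> model.language_model.layers.{i}.self_attn.q_proj.weight
+    #   text_model.layers.{i}.feed_forward.w1.weight -> model.language_model.layers.{i}.mlp.gate_proj.weight
+    #   text_model.layers.{i}.attention_norm.weight  -> model.language_model.layers.{i}.input_layernorm.weight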
+ param_count = 0 index_dict = {"weight_map": {}} for layer_i in range(n_layers): filename = f"pytorch_language_model-{layer_i + 1}-of-{n_layers + 1}.bin" - if num_shards == 1: - # Unsharded - state_dict = { - f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight": permute( - loaded.pop(f"text_model.layers.{layer_i}.attention.wq.weight"), n_heads=n_heads - ), - f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight": permute( - loaded.pop(f"text_model.layers.{layer_i}.attention.wk.weight"), - n_heads=num_key_value_heads, - dim1=dim // num_local_key_value_heads, - ), - f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.attention.wv.weight"), - f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.attention.wo.weight"), - f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.feed_forward.w1.weight"), - f"model.language_model.layers.{layer_i}.mlp.down_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.feed_forward.w2.weight"), - f"model.language_model.layers.{layer_i}.mlp.up_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.feed_forward.w3.weight"), - f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded.pop(f"text_model.layers.{layer_i}.attention_norm.weight"), - f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded.pop(f"text_model.layers.{layer_i}.ffn_norm.weight"), - } - else: - # Sharded - # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share - # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is - # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. - - state_dict = { - f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[0].pop( - f"text_model.layers.{layer_i}.attention_norm.weight" - ).clone(), - f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0].pop( - f"text_model.layers.{layer_i}.ffn_norm.weight" - ).clone(), - } - state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( - torch.cat( - [ - loaded[i].pop(f"text_model.layers.{layer_i}.attention.wq.weight").view(n_heads_per_shard, dims_per_head, dim) - for i in range(num_shards) - ], - dim=0, - ).reshape(dim, dim), - n_heads=n_heads, - ) - state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( - torch.cat( - [ - loaded[i].pop(f"text_model.layers.{layer_i}.attention.wk.weight").view( - num_local_key_value_heads, dims_per_head, dim - ) - for i in range(num_shards) - ], - dim=0, - ).reshape(key_value_dim, dim), - num_key_value_heads, - key_value_dim, - dim, - ) - state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + # Sharded + # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share + # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is + # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. 
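+        # rough summary of the sharding handled below: wq/wk/wv and feed_forward.w1/w3 are split
+        # across shards along their output dimension, so their shards are concatenated on dim=0,
+        # while wo and feed_forward.w2 are split along the input dimension and concatenated on dim=1;
+        # `permute` then reorders the q/k rows into the rotary-embedding layout expected by the HF
+        # attention implementation.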
+ + state_dict = { + f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[0].pop( + f"text_model.layers.{layer_i}.attention_norm.weight" + ).clone(), + f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0].pop( + f"text_model.layers.{layer_i}.ffn_norm.weight" + ).clone(), + } + state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i].pop(f"text_model.layers.{layer_i}.attention.wq.weight").view(n_heads_per_shard, dims_per_head, dim) + for i in range(num_shards) + ], + dim=0, + ).reshape(dim, dim), + n_heads=n_heads, + ) + state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( [ - loaded[i].pop(f"text_model.layers.{layer_i}.attention.wv.weight").view( + loaded[i].pop(f"text_model.layers.{layer_i}.attention.wk.weight").view( num_local_key_value_heads, dims_per_head, dim ) for i in range(num_shards) ], dim=0, - ).reshape(key_value_dim, dim) - - state_dict[f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( - [loaded[i].pop(f"text_model.layers.{layer_i}.attention.wo.weight") for i in range(num_shards)], dim=1 - ) - state_dict[f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( - [loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w1.weight") for i in range(num_shards)], dim=0 - ) - state_dict[f"model.language_model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( - [loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w2.weight") for i in range(num_shards)], dim=1 - ) - state_dict[f"model.language_model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( - [loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w3.weight") for i in range(num_shards)], dim=0 - ) + ).reshape(key_value_dim, dim), + num_key_value_heads, + key_value_dim, + dim, + ) + state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i].pop(f"text_model.layers.{layer_i}.attention.wv.weight").view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim) + + state_dict[f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.layers.{layer_i}.attention.wo.weight") for i in range(num_shards)], dim=1 + ) + state_dict[f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w1.weight") for i in range(num_shards)], dim=0 + ) + state_dict[f"model.language_model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w2.weight") for i in range(num_shards)], dim=1 + ) + state_dict[f"model.language_model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w3.weight") for i in range(num_shards)], dim=0 + ) state_dict[f"model.language_model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq for k, v in state_dict.items(): @@ -217,29 +208,138 @@ def permute(w, n_heads, dim1=dim, dim2=dim): param_count += v.numel() torch.save(state_dict, os.path.join(tmp_model_path, filename)) + # save embedding layer and norm + filename = f"pytorch_language_model-{n_layers + 1}-of-{n_layers + 1}.bin" - if num_shards == 1: - # Unsharded - state_dict = { - "model.language_model.embed_tokens.weight": loaded.pop("text_model.tok_embeddings.weight"), - 
"model.language_model.norm.weight": loaded.pop("text_model.norm.weight"), - "lm_head.weight": loaded.pop("text_model.output.weight"), - } - else: - concat_dim = 0 if llama_version == 3 else 1 - state_dict = { - "model.language_model.norm.weight": loaded[0].pop("text_model.norm.weight"), - "model.language_model.embed_tokens.weight": torch.cat( - [loaded[i].pop("text_model.tok_embeddings.weight") for i in range(num_shards)], dim=concat_dim - ), - "lm_head.weight": torch.cat([loaded[i].pop("text_model.output.weight") for i in range(num_shards)], dim=0), - } + concat_dim = 0 if llama_version == 3 else 1 + state_dict = { + "model.language_model.norm.weight": loaded[0].pop("text_model.norm.weight"), + "model.language_model.embed_tokens.weight": torch.cat( + [loaded[i].pop("text_model.tok_embeddings.weight") for i in range(num_shards)], dim=concat_dim + ), + "lm_head.weight": torch.cat([loaded[i].pop("text_model.output.weight") for i in range(num_shards)], dim=0), + } for k, v in state_dict.items(): index_dict["weight_map"][k] = filename param_count += v.numel() torch.save(state_dict, os.path.join(tmp_model_path, filename)) + # Then, cross-attention layers from the language model + + for xattn_layer_i in range(n_layers_cross_attention): + cross_attentions_filename = f"pytorch_language_model_xattn-{xattn_layer_i + 1}-of-{n_layers_cross_attention + 1}.bin" + + # norms + + state_dict = { + f"model.language_model.cross_attention_layers.{xattn_layer_i}.input_layernorm.weight": loaded[0].pop( + f"text_model.cross_attention_layers.{xattn_layer_i}.attention_norm.weight" + ).clone(), + f"model.language_model.cross_attention_layers.{xattn_layer_i}.post_attention_layernorm.weight": loaded[0].pop( + f"text_model.cross_attention_layers.{xattn_layer_i}.ffn_norm.weight" + ).clone(), + } + + # projections + + state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.feed_forward.w1.weight") for i in range(num_shards)], dim=0 + ) + state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.feed_forward.w2.weight") for i in range(num_shards)], dim=1 + ) + state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.feed_forward.w3.weight") for i in range(num_shards)], dim=0 + ) + + # attention weights + + state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wq.weight").view(n_heads_per_shard, dims_per_head, dim) + for i in range(num_shards) + ], + dim=0, + ).reshape(dim, dim), + n_heads=n_heads, + ) + state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( + [ + loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wk.weight").view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim), + num_key_value_heads, + key_value_dim, + dim, + ) + state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wv.weight").view( + 
num_local_key_value_heads, dims_per_head, dim
+                )
+                for i in range(num_shards)
+            ],
+            dim=0,
+        ).reshape(key_value_dim, dim)
+
+        state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.o_proj.weight"] = torch.cat(
+            [loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wo.weight") for i in range(num_shards)], dim=1
+        )
+
+        # attention and feed-forward gates: reduce each shard's gate to a single element before
+        # concatenation (mimics the loading hook from the original authors)
+        ffn_gate = []
+        attn_gate = []
+        for i in range(num_shards):
+            attn_gate.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.gate_attn"))
+            if attn_gate[i].dim() == 1:
+                attn_gate[i] = attn_gate[i][0].view(1)
+            if attn_gate[i].dim() == 3:
+                attn_gate[i] = attn_gate[i].view(1)
+
+            ffn_gate.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.gate_ffwd"))
+            if ffn_gate[i].dim() == 1:
+                ffn_gate[i] = ffn_gate[i][0].view(1)
+            if ffn_gate[i].dim() == 3:
+                ffn_gate[i] = ffn_gate[i].view(1)
+        state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.attn_gate"] = torch.cat(
+            [attn_gate[i] for i in range(num_shards)], dim=0
+        )
+        state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.ffn_gate"] = torch.cat(
+            [ffn_gate[i] for i in range(num_shards)], dim=0
+        )
+
+        # q and k normalization weights (used to stabilize cross-attention during training)
+        q_weight = []
+        k_weight = []
+        for i in range(num_shards):
+            q_weight.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.inner_attention.q_norm.weight"))
+            k_weight.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.inner_attention.k_norm.weight"))
+        state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.attention.q_norm.weight"] = torch.cat(
+            [q_weight[i] for i in range(num_shards)], dim=0
+        )
+        state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.attention.k_norm.weight"] = torch.cat(
+            [k_weight[i] for i in range(num_shards)], dim=0
+        )
+
+        # save the state dict of this cross-attention layer
+        for k, v in state_dict.items():
+            index_dict["weight_map"][k] = cross_attentions_filename
+            param_count += v.numel()
+        torch.save(state_dict, os.path.join(tmp_model_path, cross_attentions_filename))
+
+    # then, convert the vision model's two transformers (2 sets of layers, same width)
+
     # Write configs
     index_dict["metadata"] = {"total_size": param_count * 2}
     write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))