Commit

add conversion of cross-attention layers
molbap committed Jul 5, 2024
1 parent 4b88aad commit 462bf3c
Showing 1 changed file with 191 additions and 91 deletions.
282 changes: 191 additions & 91 deletions src/transformers/models/mllama/convert_mllama_weights_to_hf.py
@@ -21,7 +21,7 @@

import torch

from transformers import LlamaConfig, LlamaForCausalLM, CLIPVisionConfig, LlamaTokenizer, PreTrainedTokenizerFast
from transformers import LlamaConfig, LlamaForCausalLM, MllamaConfig, CLIPVisionConfig, LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import TikTokenConverter


@@ -83,7 +83,10 @@ def write_model(
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
params = params.get("model", params)
n_layers = params["n_layers"]
n_layers = params["n_layers"] # language model self-attention layers
n_layers_cross_attention = 20 # language model cross-attention layers
n_layers_vision_transformer = 32 # vision model first transformer layers
n_layers_global_transformer = 8 # vision model global transformer layers
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
@@ -127,119 +130,216 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
for i in range(num_shards)
]
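The hunk header above references the permute helper, whose body sits outside the shown context. For reference, the permutation conventionally used in Hugging Face Llama conversion scripts, which regroups Meta's interleaved rotary pairs into the half-rotated layout expected by transformers, looks like the sketch below (the exact body in this file may differ; dim1 and dim2 default to the enclosing write_model's dim, as in the signature shown above):

def permute(w, n_heads, dim1=dim, dim2=dim):
    # split each head into rotary pairs, move the pair axis forward, then flatten back
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)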


# weights are loaded. Now, we convert them into a state_dict.

# first, language model weights.

# we start with self-attention layers.

param_count = 0
index_dict = {"weight_map": {}}
for layer_i in range(n_layers):
filename = f"pytorch_language_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight": permute(
loaded.pop(f"text_model.layers.{layer_i}.attention.wq.weight"), n_heads=n_heads
),
f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight": permute(
loaded.pop(f"text_model.layers.{layer_i}.attention.wk.weight"),
n_heads=num_key_value_heads,
dim1=dim // num_local_key_value_heads,
),
f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.attention.wv.weight"),
f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.attention.wo.weight"),
f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.feed_forward.w1.weight"),
f"model.language_model.layers.{layer_i}.mlp.down_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.feed_forward.w2.weight"),
f"model.language_model.layers.{layer_i}.mlp.up_proj.weight": loaded.pop(f"text_model.layers.{layer_i}.feed_forward.w3.weight"),
f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded.pop(f"text_model.layers.{layer_i}.attention_norm.weight"),
f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded.pop(f"text_model.layers.{layer_i}.ffn_norm.weight"),
}
else:
# Sharded
# Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.

state_dict = {
f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[0].pop(
f"text_model.layers.{layer_i}.attention_norm.weight"
).clone(),
f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0].pop(
f"text_model.layers.{layer_i}.ffn_norm.weight"
).clone(),
}
state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i].pop(f"text_model.layers.{layer_i}.attention.wq.weight").view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim),
n_heads=n_heads,
)
state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i].pop(f"text_model.layers.{layer_i}.attention.wk.weight").view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim),
num_key_value_heads,
key_value_dim,
dim,
)
state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
# Sharded
# Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.

state_dict = {
f"model.language_model.layers.{layer_i}.input_layernorm.weight": loaded[0].pop(
f"text_model.layers.{layer_i}.attention_norm.weight"
).clone(),
f"model.language_model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0].pop(
f"text_model.layers.{layer_i}.ffn_norm.weight"
).clone(),
}
state_dict[f"model.language_model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i].pop(f"text_model.layers.{layer_i}.attention.wq.weight").view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim),
n_heads=n_heads,
)
state_dict[f"model.language_model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i].pop(f"text_model.layers.{layer_i}.attention.wv.weight").view(
loaded[i].pop(f"text_model.layers.{layer_i}.attention.wk.weight").view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim)

state_dict[f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.attention.wo.weight") for i in range(num_shards)], dim=1
)
state_dict[f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w1.weight") for i in range(num_shards)], dim=0
)
state_dict[f"model.language_model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w2.weight") for i in range(num_shards)], dim=1
)
state_dict[f"model.language_model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w3.weight") for i in range(num_shards)], dim=0
)
).reshape(key_value_dim, dim),
num_key_value_heads,
key_value_dim,
dim,
)
state_dict[f"model.language_model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i].pop(f"text_model.layers.{layer_i}.attention.wv.weight").view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim)

state_dict[f"model.language_model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.attention.wo.weight") for i in range(num_shards)], dim=1
)
state_dict[f"model.language_model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w1.weight") for i in range(num_shards)], dim=0
)
state_dict[f"model.language_model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w2.weight") for i in range(num_shards)], dim=1
)
state_dict[f"model.language_model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.layers.{layer_i}.feed_forward.w3.weight") for i in range(num_shards)], dim=0
)

state_dict[f"model.language_model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
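As the sharded-path comments above note, torch.save serializes a tensor's entire underlying storage, so saving a norm weight that still shares storage with the other shard weights would drag those weights into the file as well; cloning detaches it first. A minimal standalone illustration of that behavior (not part of the conversion script):

import torch

base = torch.zeros(1_000_000)
view = base[:10]                      # shares storage with base
torch.save(view, "view.pt")           # file contains the full 1M-element storage
torch.save(view.clone(), "clone.pt")  # file contains only the 10 copied elements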

# save embedding layer, final norm and LM head

filename = f"pytorch_language_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
"model.language_model.embed_tokens.weight": loaded.pop("text_model.tok_embeddings.weight"),
"model.language_model.norm.weight": loaded.pop("text_model.norm.weight"),
"lm_head.weight": loaded.pop("text_model.output.weight"),
}
else:
concat_dim = 0 if llama_version == 3 else 1
state_dict = {
"model.language_model.norm.weight": loaded[0].pop("text_model.norm.weight"),
"model.language_model.embed_tokens.weight": torch.cat(
[loaded[i].pop("text_model.tok_embeddings.weight") for i in range(num_shards)], dim=concat_dim
),
"lm_head.weight": torch.cat([loaded[i].pop("text_model.output.weight") for i in range(num_shards)], dim=0),
}
concat_dim = 0 if llama_version == 3 else 1
state_dict = {
"model.language_model.norm.weight": loaded[0].pop("text_model.norm.weight"),
"model.language_model.embed_tokens.weight": torch.cat(
[loaded[i].pop("text_model.tok_embeddings.weight") for i in range(num_shards)], dim=concat_dim
),
"lm_head.weight": torch.cat([loaded[i].pop("text_model.output.weight") for i in range(num_shards)], dim=0),
}

for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
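The concat_dim switch above mirrors how Meta shards the token embedding across checkpoint files: along the vocabulary dimension (dim 0) for Llama 3, along the hidden dimension (dim 1) for earlier releases. A toy sketch with illustrative sizes:

import torch

num_shards, vocab_size, hidden_size = 8, 128_256, 4_096
# Llama-3-style sharding: each checkpoint file holds a slice of the vocabulary
embedding_shards = [torch.zeros(vocab_size // num_shards, hidden_size) for _ in range(num_shards)]
embed_tokens = torch.cat(embedding_shards, dim=0)  # -> (128256, 4096)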

# Then, cross-attention layers from the language model

for xattn_layer_i in range(n_layers_cross_attention):
cross_attentions_filename = f"pytorch_language_model_xattn-{xattn_layer_i + 1}-of-{n_layers_cross_attention + 1}.bin"

# norms

state_dict = {
f"model.language_model.cross_attention_layers.{xattn_layer_i}.input_layernorm.weight": loaded[0].pop(
f"text_model.cross_attention_layers.{xattn_layer_i}.attention_norm.weight"
).clone(),
f"model.language_model.cross_attention_layers.{xattn_layer_i}.post_attention_layernorm.weight": loaded[0].pop(
f"text_model.cross_attention_layers.{xattn_layer_i}.ffn_norm.weight"
).clone(),
}

# projections

state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.feed_forward.w1.weight") for i in range(num_shards)], dim=0
)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.feed_forward.w2.weight") for i in range(num_shards)], dim=1
)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.feed_forward.w3.weight") for i in range(num_shards)], dim=0
)

# attention weights

state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.q_proj.weight"] = permute(
torch.cat(
[
loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wq.weight").view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim),
n_heads=n_heads,
)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.k_proj.weight"] = permute(
torch.cat(
[
loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wk.weight").view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim),
num_key_value_heads,
key_value_dim,
dim,
)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wv.weight").view(
num_local_key_value_heads, dims_per_head, dim
)
for i in range(num_shards)
],
dim=0,
).reshape(key_value_dim, dim)

state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.wo.weight") for i in range(num_shards)], dim=1
)

# attention and feed-forward gates (to mimic the loading hook from the authors)
ffn_gate = []
attn_gate = []
for i in range(num_shards):
attn_gate.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.gate_attn"))
if attn_gate[i].dim() == 1:
attn_gate[i] = attn_gate[i][0].view(1)
if attn_gate[i].dim() == 3:
attn_gate[i] = attn_gate[i].view(1)

ffn_gate.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.gate_ffwd"))

if ffn_gate[i].dim() == 1:
ffn_gate[i] = ffn_gate[i][0].view(1)
if ffn_gate[i].dim() == 3:
ffn_gate[i] = ffn_gate[i].view(1)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.attn_gate"] = torch.cat(
[attn_gate[i] for i in range(num_shards)], dim=0
)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.ffn_gate"] = torch.cat(
[ffn_gate[i] for i in range(num_shards)], dim=0
)

# q and k normalization weights (for cross-attention stability in training)

q_weight = []
k_weight = []
for i in range(num_shards):
q_weight.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.inner_attention.q_norm.weight"))
k_weight.append(loaded[i].pop(f"text_model.cross_attention_layers.{xattn_layer_i}.attention.inner_attention.k_norm.weight"))
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.attention.q_norm.weight"] = torch.cat(
[q_weight[i] for i in range(num_shards)], dim=0
)
state_dict[f"model.language_model.cross_attention_layers.{xattn_layer_i}.attention.k_norm.weight"] = torch.cat(
[k_weight[i] for i in range(num_shards)], dim=0
)


# save state dict of this layer

for k, v in state_dict.items():
index_dict["weight_map"][k] = cross_attentions_filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, cross_attentions_filename))
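The gate handling above collapses each shard's gate_attn / gate_ffwd tensor to a single-element vector, whether it was stored with one dimension or three, while the q_norm / k_norm weights are concatenated across shards like any other weight. A compact equivalent of that gate normalization, using a hypothetical helper name:

import torch

def to_scalar_gate(gate: torch.Tensor) -> torch.Tensor:
    # keep only the first element, returned as a 1-D single-element tensor,
    # matching the dim() == 1 and dim() == 3 branches above
    return gate.reshape(-1)[:1]

# to_scalar_gate(torch.tensor([0.5, 0.5])) and to_scalar_gate(torch.full((1, 1, 1), 0.5))
# both return tensor([0.5000])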

# then, convert the vision model's two transformer stacks (the 32-layer vision transformer and the 8-layer global transformer, same width)

# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
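The index written above follows the standard sharded-checkpoint layout: a metadata.total_size in bytes (param_count * 2 assumes 2-byte bf16/fp16 weights) plus a weight_map from parameter name to shard file. An illustrative excerpt with hypothetical layer counts and sizes (here 40 self-attention layers):

index_example = {
    "metadata": {"total_size": 21_000_000_000},  # hypothetical byte count
    "weight_map": {
        "model.language_model.layers.0.self_attn.q_proj.weight": "pytorch_language_model-1-of-41.bin",
        "model.language_model.cross_attention_layers.0.attn_gate": "pytorch_language_model_xattn-1-of-21.bin",
        "lm_head.weight": "pytorch_language_model-41-of-41.bin",
    },
}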
