
Commit ae43243

make conversion user dependent on model type, defaults for original paper models

1 parent a77d15b · commit ae43243

File tree: 1 file changed (+98, −68 lines)


src/transformers/models/mamba2/convert_mamba2_ssm_checkpoint_to_pytorch.py

Lines changed: 98 additions & 68 deletions
@@ -16,8 +16,9 @@
 
 import argparse
 import json
+from functools import partial
 from os import path
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional
 
 import torch
 from safetensors import safe_open
@@ -26,46 +27,26 @@
 from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM, Mamba2TokenizerFast
 
 
-_MAMBA2_MODELS_DICT = {
-    "codestral": {
-        "hidden_size": "dim",
-        "num_hidden_layers": "n_layers",
-        "n_groups": "n_groups",
-        "residual_in_fp32": "residual_in_fp32",
-        "tie_word_embeddings": "tie_embeddings",
-        "norm_before_gate": False,
-        "vocab_size": "vocab_size",
-        "pad_vocab_size_multiple": "pad_vocab_size_multiple",
-        "bos_token_id": 0,
-        "pad_token_id": 1,
-        "eos_token_id": 2,
-    },
-    "base": {
-        "hidden_size": "d_model",
-        "num_hidden_layers": "n_layer",
-        "n_groups": "ngroups",
-        "residual_in_fp32": "residual_in_fp32",
-        "tie_word_embeddings": "tie_embeddings",
-        "norm_before_gate": False,
-        "vocab_size": "vocab_size",
-        "pad_vocab_size_multiple": "pad_vocab_size_multiple",
-        "bos_token_id": 0,
-        "pad_token_id": 0,
-        "eos_token_id": 0,
-    },
-}
+def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+    # Load weights and config from paths
+    original_state_dict = {}
+    with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
+        for k in f.keys():
+            newk = k.removeprefix("model.")
+            original_state_dict[newk] = f.get_tensor(k).clone()
+    return original_state_dict
 
 
-def convert_ssm_config_to_hf_config(config_ssm: Dict) -> Tuple[Mamba2Config, bool]:
+def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+    return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")
+
+
+def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
     """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
     hf_config = Mamba2Config()
 
-    # Flag for codestral model
-    is_not_codestral = "dim" not in config_ssm
-
     # Switch to a different dict depending on model type
-    config_key = "base" if is_not_codestral else "codestral"
-    config_dict = _MAMBA2_MODELS_DICT[config_key]
+    config_dict = mamba2_model_dict
 
     # Set important values from config and recalculate other resulting entries
     hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
@@ -86,43 +67,82 @@ def convert_ssm_config_to_hf_config(config_ssm: Dict) -> Tuple[Mamba2Config, boo
         vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
     hf_config.vocab_size = vocab_size
 
-    return hf_config, is_not_codestral
+    return hf_config
 
 
-def load_state_dict_from_safetensors(mamba2_checkpoint_path: str) -> Dict[str, torch.Tensor]:
-    # Load weights and config from paths
-    original_state_dict = {}
-    with safe_open(path.join(mamba2_checkpoint_path, "consolidated.safetensors"), framework="pt") as f:
-        for k in f.keys():
-            newk = k.removeprefix("model.")
-            original_state_dict[newk] = f.get_tensor(k).clone()
-    return original_state_dict
+def load_and_save_tokenizer(
+    mamba2_model_type: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    tokenizer = None
 
+    # Load tokenizer
+    if tokenizer_model_path is not None and mamba2_model_type == "codestral":
+        tokenizer_class = LlamaTokenizerFast
+        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
+    elif mamba2_model_type == "mamba_ssm":
+        tokenizer = Mamba2TokenizerFast.from_pretrained("state-spaces/mamba-130m-hf")
 
-def load_state_dict_from_torch(mamba2_checkpoint_path: str) -> Dict[str, torch.Tensor]:
-    return torch.load(path.join(mamba2_checkpoint_path, "pytorch_model.bin"), map_location="cpu")
+    # Save tokenizer
+    if tokenizer is not None:
+        tokenizer.save_pretrained(output_dir)
+
+
+_MAMBA2_MODELS_DICT = {
+    "codestral": {
+        "hidden_size": "dim",
+        "num_hidden_layers": "n_layers",
+        "n_groups": "n_groups",
+        "residual_in_fp32": "residual_in_fp32",
+        "tie_word_embeddings": "tie_embeddings",
+        "norm_before_gate": False,
+        "vocab_size": "vocab_size",
+        "pad_vocab_size_multiple": "pad_vocab_size_multiple",
+        "bos_token_id": 0,
+        "pad_token_id": 1,
+        "eos_token_id": 2,
+        "config_name": "params.json",
+        "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
+    },
+    "mamba_ssm": {
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layer",
+        "n_groups": "ngroups",
+        "residual_in_fp32": "residual_in_fp32",
+        "tie_word_embeddings": "tie_embeddings",
+        "norm_before_gate": False,
+        "vocab_size": "vocab_size",
+        "pad_vocab_size_multiple": "pad_vocab_size_multiple",
+        "bos_token_id": 0,
+        "pad_token_id": 0,
+        "eos_token_id": 0,
+        "config_name": "config.json",
+        "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
+    },
+}
 
 
 def convert_mamba2_checkpoint_file_to_huggingface_model_file(
-    mamba2_checkpoint_path: str, precision: str, output_dir: str, tokenizer_model_path: Optional[str] = None
+    mamba2_checkpoint_path: str,
+    mamba2_model_type: str,
+    precision: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
 ) -> None:
+    mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]
+
     # Load and save config based on name
-    config_path = mamba2_checkpoint_path
-    config_path = (
-        path.join(config_path, "params.json")
-        if path.isfile(path.join(config_path, "params.json"))
-        else path.join(config_path, "config.json")
-    )
+    config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
     with open(config_path, "r", encoding="utf-8") as json_file:
         config = json.load(json_file)
-    hf_config, is_not_codestral = convert_ssm_config_to_hf_config(config)
+    hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
    hf_config.save_pretrained(output_dir)
 
-    # Load state dict of the original model
-    state_dict_load_function = load_state_dict_from_torch if is_not_codestral else load_state_dict_from_safetensors
-    original_state_dict = state_dict_load_function(mamba2_checkpoint_path)
-
-    # Load and transfer to hf model
+    # Load state dict of the original model and transfer to hf model
+    original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
     hf_model = Mamba2ForCausalLM(hf_config)
     hf_model.load_state_dict(original_state_dict)
 
@@ -131,12 +151,7 @@ def convert_mamba2_checkpoint_file_to_huggingface_model_file(
     save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})
 
     # Load and save tokenizer
-    if tokenizer_model_path is not None and not is_not_codestral:
-        tokenizer_class = LlamaTokenizerFast
-        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
-    else:
-        tokenizer = Mamba2TokenizerFast.from_pretrained("state-spaces/mamba-130m-hf")
-    tokenizer.save_pretrained(output_dir)
+    mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)
 
 
 if __name__ == "__main__":
@@ -148,13 +163,24 @@ def convert_mamba2_checkpoint_file_to_huggingface_model_file(
         required=True,
         help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
     )
+    parser.add_argument(
+        "-m",
+        "--mamba2_model_type",
+        type=str,
+        default="mamba_ssm",
+        const="mamba_ssm",
+        required=True,
+        choices=("codestral", "mamba_ssm"),
+        help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
+    )
     parser.add_argument(
         "-p",
         "--precision",
         type=str,
-        default="bf16",
+        default="fp16",
+        const="fp16",
         required=True,
-        choices=["fp32", "fp16", "bf16"],
+        choices=("fp32", "fp16", "bf16"),
         help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
     )
     parser.add_argument(
@@ -171,5 +197,9 @@ def convert_mamba2_checkpoint_file_to_huggingface_model_file(
     args = parser.parse_args()
 
     convert_mamba2_checkpoint_file_to_huggingface_model_file(
-        args.mamba2_checkpoint_file, args.precision, args.output_dir, args.tokenizer_model_path
+        args.mamba2_checkpoint_directory,
+        args.mamba2_model_type,
+        args.precision,
+        args.output_dir,
+        args.tokenizer_model_path,
     )
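For reference, a minimal usage sketch of the updated entry point after this commit. The paths below are placeholders and the import path is assumed from the file location in this commit, not stated in it. The `mamba2_model_type` argument selects the per-type defaults in `_MAMBA2_MODELS_DICT`: `"mamba_ssm"` for the original paper checkpoints (`config.json` plus `pytorch_model.bin`) and `"codestral"` for `params.json` plus `consolidated.safetensors`.

# Sketch only: placeholder paths; module path assumed from the script's location in the source tree.
from transformers.models.mamba2.convert_mamba2_ssm_checkpoint_to_pytorch import (
    convert_mamba2_checkpoint_file_to_huggingface_model_file,
)

convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="/path/to/original/checkpoint_dir",
    mamba2_model_type="mamba_ssm",  # or "codestral"
    precision="fp16",
    output_dir="/path/to/hf_output_dir",
    tokenizer_model_path=None,  # only consumed by the "codestral" tokenizer branch
)

Compared to the previous revision, the caller now names the model type explicitly instead of the script inferring it via `is_not_codestral = "dim" not in config_ssm`; the per-type dictionary then supplies the config filename, the `partial`-bound state-dict loader, and the tokenizer handler.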
