
Commit 92a75ff

Mamba2 conversion script for original models (#32580)
* first attempt at allowing both conversions from codestral and from the original mamba ssm
* allow fp16, seems default for mamba2
* dtype fix
* simplify codestral check, dont overwrite pad/eos/bos when codestral
* change file -> directory
* use path join to be safe
* style
* apply code review - add util mamba2 tokenizer (gptneox with left padding) - add models dict
* fix copies
* add tokenizer to docs
* empty commit to check for weird err
* make conversion user dependent on model type, defaults for original paper models
* small comment nit
* remove norm_before_gate in conversion
* simplify model dict by using shared keys directly + remove unnecessary attributes
* fix tokenization: remove separate mamba2 tokenizer, add padding option as kwarg to gptneox one and reuse it for the conversion script
* simplify even further as we pass padding side via **kwargs already
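As a rough usage sketch of what this change enables (not part of the commit): the converter is now driven per model type, with defaults matching the original paper checkpoints (`mamba_ssm`, fp16). The import path follows the file location in this diff; the checkpoint directories and the codestral tokenizer file name below are placeholders.

# Hypothetical usage sketch; directory paths and tokenizer file name are placeholders.
from transformers.models.mamba2.convert_mamba2_ssm_checkpoint_to_pytorch import (
    convert_mamba2_checkpoint_file_to_huggingface_model_file,
)

# Original paper checkpoint (directory with config.json + pytorch_model.bin),
# converted with the defaults introduced here: mamba_ssm model type, fp16.
convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="./mamba2-370m-original",
    mamba2_model_type="mamba_ssm",
    precision="fp16",
    output_dir="./mamba2-370m-hf",
)

# Codestral-style checkpoint (params.json + consolidated.safetensors); only this
# branch uses the LlamaTokenizerFast path, so the tokenizer model file is passed.
convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="./mamba-codestral-original",
    mamba2_model_type="codestral",
    precision="bf16",
    output_dir="./mamba-codestral-hf",
    tokenizer_model_path="./mamba-codestral-original/tokenizer.model",
)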
1 parent 39bfb2f commit 92a75ff

File tree

1 file changed: +141 -17 lines changed


src/transformers/models/mamba2/convert_mamba2_ssm_checkpoint_to_pytorch.py

Lines changed: 141 additions & 17 deletions
@@ -15,55 +15,179 @@
 """This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
 
 import argparse
+import json
+from functools import partial
+from os import path
+from typing import Dict, Optional
 
 import torch
 from safetensors import safe_open
+from safetensors.torch import save_model
 
-from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
+from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
 
 
-def convert_mamba2_checkpoint_file_to_huggingface_model_file(
-    mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str
-) -> None:
-    hf_config = Mamba2Config()
-    hf_model = Mamba2ForCausalLM(hf_config)
+def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
     # Load weights and config from paths
     original_state_dict = {}
-    with safe_open(mamba2_checkpoint_path, framework="pt") as f:
+    with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
         for k in f.keys():
             newk = k.removeprefix("model.")
             original_state_dict[newk] = f.get_tensor(k).clone()
+    return original_state_dict
+
+
+def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+    return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")
+
+
+def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
+    """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
+    hf_config = Mamba2Config()
+
+    # Switch to a different dict depending on model type
+    config_dict = mamba2_model_dict
+
+    # Set important values from config and recalculate other resulting entries
+    hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
+    hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
+    hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
+    hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
+    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
+    hf_config.bos_token_id = config_dict["bos_token_id"]
+    hf_config.pad_token_id = config_dict["pad_token_id"]
+    hf_config.eos_token_id = config_dict["eos_token_id"]
+
+    # Padded vocab size, mostly of 16 but 32 is also very common in different models
+    vocab_size = config_ssm["vocab_size"]
+    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
+    if (vocab_size % pad_vocab_size_multiple) != 0:
+        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+    hf_config.vocab_size = vocab_size
+
+    return hf_config
+
+
+def load_and_save_tokenizer(
+    mamba2_model_type: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    tokenizer = None
+
+    # Load tokenizer
+    if tokenizer_model_path is not None and mamba2_model_type == "codestral":
+        tokenizer_class = LlamaTokenizerFast
+        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
+    elif mamba2_model_type == "mamba_ssm":
+        tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")
+
+    # Save tokenizer
+    if tokenizer is not None:
+        tokenizer.save_pretrained(output_dir)
 
+
+_MAMBA2_MODELS_DICT = {
+    "codestral": {
+        "hidden_size": "dim",
+        "num_hidden_layers": "n_layers",
+        "n_groups": "n_groups",
+        "bos_token_id": 0,
+        "pad_token_id": 1,
+        "eos_token_id": 2,
+        "config_name": "params.json",
+        "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
+    },
+    "mamba_ssm": {
+        "hidden_size": "d_model",
+        "num_hidden_layers": "n_layer",
+        "n_groups": "ngroups",
+        "bos_token_id": 0,
+        "pad_token_id": 0,
+        "eos_token_id": 0,
+        "config_name": "config.json",
+        "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
+        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
+    },
+}
+
+
+def convert_mamba2_checkpoint_file_to_huggingface_model_file(
+    mamba2_checkpoint_path: str,
+    mamba2_model_type: str,
+    precision: str,
+    output_dir: str,
+    tokenizer_model_path: Optional[str] = None,
+) -> None:
+    mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]
+
+    # Load and save config based on name
+    config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
+    with open(config_path, "r", encoding="utf-8") as json_file:
+        config = json.load(json_file)
+    hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
+    hf_config.save_pretrained(output_dir)
+
+    # Load state dict of the original model and transfer to hf model
+    original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
+    hf_model = Mamba2ForCausalLM(hf_config)
     hf_model.load_state_dict(original_state_dict)
 
     # Save new model to pytorch_dump_path
-    hf_model.to(torch.bfloat16).save_pretrained(output_dir)
-    tokenizer_class = LlamaTokenizerFast
-    tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
-    tokenizer.save_pretrained(output_dir)
+    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
+    save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})
+
+    # Load and save tokenizer
+    mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-i",
-        "--mamba2_checkpoint_file",
+        "--mamba2_checkpoint_directory",
         type=str,
         required=True,
-        help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.",
+        help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
     )
     parser.add_argument(
-        "-c",
-        "--tokenizer_model_path",
+        "-m",
+        "--mamba2_model_type",
+        type=str,
+        default="mamba_ssm",
+        const="mamba_ssm",
+        required=True,
+        choices=("codestral", "mamba_ssm"),
+        help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
+    )
+    parser.add_argument(
+        "-p",
+        "--precision",
         type=str,
+        default="fp16",
+        const="fp16",
         required=True,
-        help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.",
+        choices=("fp32", "fp16", "bf16"),
+        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
     )
     parser.add_argument(
         "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
     )
+    parser.add_argument(
+        "-t",
+        "--tokenizer_model_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Path to a `codestral` tokenizer file.",
+    )
     args = parser.parse_args()
 
     convert_mamba2_checkpoint_file_to_huggingface_model_file(
-        args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir
+        args.mamba2_checkpoint_directory,
+        args.mamba2_model_type,
+        args.precision,
+        args.output_dir,
+        args.tokenizer_model_path,
     )
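For completeness, a hedged sanity-check sketch (not part of the diff) of loading a directory produced by the converted script; the local path, prompt, and generation settings are illustrative placeholders.

# Hypothetical check of a converted checkpoint; the directory name is a placeholder.
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM

# The output directory holds config.json, model.safetensors, and the saved tokenizer
# (GPT-NeoX tokenizer with padding_side="left" for the mamba_ssm case).
tokenizer = AutoTokenizer.from_pretrained("./mamba2-370m-hf")
model = Mamba2ForCausalLM.from_pretrained("./mamba2-370m-hf", torch_dtype=torch.float16)

inputs = tokenizer("Mamba is a state space model", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))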
