|
15 | 15 | """This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format used by HuggingFace `transformers`. It requires the `mamba2_ssm` package to be installed."""
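| | +# Example invocation (the script name and paths below are placeholders):
| | +#   python convert_mamba2_ssm_checkpoint.py -i /path/to/mamba2_checkpoint_dir -m mamba_ssm -p bf16 -o /path/to/output_dir
| | +# For `codestral` checkpoints, additionally pass the tokenizer file via `-t /path/to/tokenizer.model`.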
16 | 16 |
|
17 | 17 | import argparse |
| 18 | +import json |
| 19 | +from functools import partial |
| 20 | +from os import path |
| 21 | +from typing import Dict, Optional |
18 | 22 |
|
19 | 23 | import torch |
20 | 24 | from safetensors import safe_open |
| 25 | +from safetensors.torch import save_model |
21 | 26 |
|
22 | | -from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM |
| 27 | +from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM |
23 | 28 |
|
24 | 29 |
|
25 | | -def convert_mamba2_checkpoint_file_to_huggingface_model_file( |
26 | | - mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str |
27 | | -) -> None: |
28 | | - hf_config = Mamba2Config() |
29 | | - hf_model = Mamba2ForCausalLM(hf_config) |
| 30 | +def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]: |
30 | 31 |     # Load the original weights from the checkpoint
31 | 32 | original_state_dict = {} |
32 | | - with safe_open(mamba2_checkpoint_path, framework="pt") as f: |
| 33 | + with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f: |
33 | 34 | for k in f.keys(): |
34 | 35 | newk = k.removeprefix("model.") |
35 | 36 | original_state_dict[newk] = f.get_tensor(k).clone() |
| 37 | + return original_state_dict |
| 38 | + |
| 39 | + |
| 40 | +def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]: |
| 41 | + return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu") |
| 42 | + |
| 43 | + |
| 44 | +def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config: |
| 45 | +    """Convert the original `mamba2_ssm` config dict into a HuggingFace `Mamba2Config`."""
| 46 | + hf_config = Mamba2Config() |
| 47 | + |
| 48 | +    # Use the key mapping that matches the selected model type
| 49 | + config_dict = mamba2_model_dict |
| 50 | + |
| 51 | + # Set important values from config and recalculate other resulting entries |
| 52 | + hf_config.hidden_size = config_ssm[config_dict["hidden_size"]] |
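| | +    # num_heads follows from the expanded hidden size divided by head_dim (expand and head_dim keep their Mamba2Config defaults)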
| 53 | + hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim |
| 54 | + hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]] |
| 55 | + hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1) |
| 56 | + hf_config.tie_word_embeddings = config_ssm["tie_embeddings"] |
| 57 | + hf_config.bos_token_id = config_dict["bos_token_id"] |
| 58 | + hf_config.pad_token_id = config_dict["pad_token_id"] |
| 59 | + hf_config.eos_token_id = config_dict["eos_token_id"] |
| 60 | + |
| 61 | +    # Pad the vocab size up to a multiple of `pad_vocab_size_multiple` (usually 16, though 32 is also common)
| 62 | + vocab_size = config_ssm["vocab_size"] |
| 63 | + pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"] |
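| | +    # e.g. a vocab of 50277 with a pad multiple of 16 rounds up to 50288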
| 64 | + if (vocab_size % pad_vocab_size_multiple) != 0: |
| 65 | + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) |
| 66 | + hf_config.vocab_size = vocab_size |
| 67 | + |
| 68 | + return hf_config |
| 69 | + |
| 70 | + |
| 71 | +def load_and_save_tokenizer( |
| 72 | + mamba2_model_type: str, |
| 73 | + output_dir: str, |
| 74 | + tokenizer_model_path: Optional[str] = None, |
| 75 | +) -> None: |
| 76 | + tokenizer = None |
| 77 | + |
| 78 | + # Load tokenizer |
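| | +    # `codestral` ships its own tokenizer model file; `mamba_ssm` checkpoints fall back to the GPT-NeoX tokenizer hosted on the Hub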
| 79 | + if tokenizer_model_path is not None and mamba2_model_type == "codestral": |
| 80 | + tokenizer_class = LlamaTokenizerFast |
| 81 | + tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True) |
| 82 | + elif mamba2_model_type == "mamba_ssm": |
| 83 | + tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left") |
| 84 | + |
| 85 | + # Save tokenizer |
| 86 | + if tokenizer is not None: |
| 87 | + tokenizer.save_pretrained(output_dir) |
36 | 88 |
|
| 89 | + |
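| | +# Per-model-type lookup table: original config key names for the HF fields, special token ids,
| | +# config file name, and helpers that load the state dict and save the tokenizer.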
| 90 | +_MAMBA2_MODELS_DICT = { |
| 91 | + "codestral": { |
| 92 | + "hidden_size": "dim", |
| 93 | + "num_hidden_layers": "n_layers", |
| 94 | + "n_groups": "n_groups", |
| 95 | + "bos_token_id": 0, |
| 96 | + "pad_token_id": 1, |
| 97 | + "eos_token_id": 2, |
| 98 | + "config_name": "params.json", |
| 99 | + "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"), |
| 100 | + "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"), |
| 101 | + }, |
| 102 | + "mamba_ssm": { |
| 103 | + "hidden_size": "d_model", |
| 104 | + "num_hidden_layers": "n_layer", |
| 105 | + "n_groups": "ngroups", |
| 106 | + "bos_token_id": 0, |
| 107 | + "pad_token_id": 0, |
| 108 | + "eos_token_id": 0, |
| 109 | + "config_name": "config.json", |
| 110 | + "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"), |
| 111 | + "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"), |
| 112 | + }, |
| 113 | +} |
| 114 | + |
| 115 | + |
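| | +# End-to-end conversion: translate the original config, transfer the weights into a HF
| | +# Mamba2ForCausalLM, save them as safetensors in the requested precision, and save the matching tokenizer.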
| 116 | +def convert_mamba2_checkpoint_file_to_huggingface_model_file( |
| 117 | + mamba2_checkpoint_path: str, |
| 118 | + mamba2_model_type: str, |
| 119 | + precision: str, |
| 120 | + output_dir: str, |
| 121 | + tokenizer_model_path: Optional[str] = None, |
| 122 | +) -> None: |
| 123 | + mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type] |
| 124 | + |
| 125 | + # Load and save config based on name |
| 126 | + config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"]) |
| 127 | + with open(config_path, "r", encoding="utf-8") as json_file: |
| 128 | + config = json.load(json_file) |
| 129 | + hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict) |
| 130 | + hf_config.save_pretrained(output_dir) |
| 131 | + |
| 132 | + # Load state dict of the original model and transfer to hf model |
| 133 | + original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path) |
| 134 | + hf_model = Mamba2ForCausalLM(hf_config) |
37 | 135 | hf_model.load_state_dict(original_state_dict) |
38 | 136 |
|
39 | 137 |     # Save the converted model weights to the output directory in the requested precision
40 | | - hf_model.to(torch.bfloat16).save_pretrained(output_dir) |
41 | | - tokenizer_class = LlamaTokenizerFast |
42 | | - tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True) |
43 | | - tokenizer.save_pretrained(output_dir) |
| 138 | + dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16) |
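| | +    # safetensors' save_model handles tied/shared tensors (e.g. tied embeddings), which a plain tensor-dict save would reject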
| 139 | + save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"}) |
| 140 | + |
| 141 | + # Load and save tokenizer |
| 142 | + mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path) |
44 | 143 |
|
45 | 144 |
|
46 | 145 | if __name__ == "__main__": |
47 | 146 | parser = argparse.ArgumentParser() |
48 | 147 | parser.add_argument( |
49 | 148 | "-i", |
50 | | - "--mamba2_checkpoint_file", |
| 149 | + "--mamba2_checkpoint_directory", |
51 | 150 | type=str, |
52 | 151 | required=True, |
53 | | - help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.", |
| 152 | + help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.", |
54 | 153 | ) |
55 | 154 | parser.add_argument( |
56 | | - "-c", |
57 | | - "--tokenizer_model_path", |
| 155 | + "-m", |
| 156 | + "--mamba2_model_type", |
| 157 | + type=str, |
| 158 | +        default="mamba_ssm",
| 160 | + required=True, |
| 161 | + choices=("codestral", "mamba_ssm"), |
| 162 | +        help="The model type to convert. Choose either `codestral` or `mamba_ssm`.",
| 163 | + ) |
| 164 | + parser.add_argument( |
| 165 | + "-p", |
| 166 | + "--precision", |
58 | 167 | type=str, |
| 168 | +        default="fp16",
59 | 170 | required=True, |
60 | | - help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.", |
| 171 | + choices=("fp32", "fp16", "bf16"), |
| 172 | + help="The precision the model will be saved in. Select from fp32, fp16 or bf16.", |
61 | 173 | ) |
62 | 174 | parser.add_argument( |
63 | 175 | "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to." |
64 | 176 | ) |
| 177 | + parser.add_argument( |
| 178 | + "-t", |
| 179 | + "--tokenizer_model_path", |
| 180 | + type=str, |
| 181 | + default=None, |
| 182 | + required=False, |
| 183 | +        help="Path to a `codestral` tokenizer file; only used when `--mamba2_model_type codestral`.",
| 184 | + ) |
65 | 185 | args = parser.parse_args() |
66 | 186 |
|
67 | 187 | convert_mamba2_checkpoint_file_to_huggingface_model_file( |
68 | | - args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir |
| 188 | + args.mamba2_checkpoint_directory, |
| 189 | + args.mamba2_model_type, |
| 190 | + args.precision, |
| 191 | + args.output_dir, |
| 192 | + args.tokenizer_model_path, |
69 | 193 | ) |