
import argparse
import json
+ from functools import partial
from os import path
- from typing import Dict, Optional, Tuple
+ from typing import Dict, Optional

import torch
from safetensors import safe_open
from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM, Mamba2TokenizerFast


- _MAMBA2_MODELS_DICT = {
-     "codestral": {
-         "hidden_size": "dim",
-         "num_hidden_layers": "n_layers",
-         "n_groups": "n_groups",
-         "residual_in_fp32": "residual_in_fp32",
-         "tie_word_embeddings": "tie_embeddings",
-         "norm_before_gate": False,
-         "vocab_size": "vocab_size",
-         "pad_vocab_size_multiple": "pad_vocab_size_multiple",
-         "bos_token_id": 0,
-         "pad_token_id": 1,
-         "eos_token_id": 2,
-     },
-     "base": {
-         "hidden_size": "d_model",
-         "num_hidden_layers": "n_layer",
-         "n_groups": "ngroups",
-         "residual_in_fp32": "residual_in_fp32",
-         "tie_word_embeddings": "tie_embeddings",
-         "norm_before_gate": False,
-         "vocab_size": "vocab_size",
-         "pad_vocab_size_multiple": "pad_vocab_size_multiple",
-         "bos_token_id": 0,
-         "pad_token_id": 0,
-         "eos_token_id": 0,
-     },
- }
+ def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+     # Load weights and config from paths
+     original_state_dict = {}
+     with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
+         for k in f.keys():
+             newk = k.removeprefix("model.")
+             original_state_dict[newk] = f.get_tensor(k).clone()
+     return original_state_dict


- def convert_ssm_config_to_hf_config(config_ssm: Dict) -> Tuple[Mamba2Config, bool]:
+ def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
+     return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")
+
+
+ def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
    """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
    hf_config = Mamba2Config()

-     # Flag for codestral model
-     is_not_codestral = "dim" not in config_ssm
-
    # Switch to a different dict depending on model type
-     config_key = "base" if is_not_codestral else "codestral"
-     config_dict = _MAMBA2_MODELS_DICT[config_key]
+     config_dict = mamba2_model_dict

    # Set important values from config and recalculate other resulting entries
    hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
@@ -86,43 +67,82 @@ def convert_ssm_config_to_hf_config(config_ssm: Dict) -> Tuple[Mamba2Config, boo
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    hf_config.vocab_size = vocab_size

-     return hf_config, is_not_codestral
+     return hf_config


- def load_state_dict_from_safetensors(mamba2_checkpoint_path: str) -> Dict[str, torch.Tensor]:
-     # Load weights and config from paths
-     original_state_dict = {}
-     with safe_open(path.join(mamba2_checkpoint_path, "consolidated.safetensors"), framework="pt") as f:
-         for k in f.keys():
-             newk = k.removeprefix("model.")
-             original_state_dict[newk] = f.get_tensor(k).clone()
-     return original_state_dict
+ def load_and_save_tokenizer(
+     mamba2_model_type: str,
+     output_dir: str,
+     tokenizer_model_path: Optional[str] = None,
+ ) -> None:
+     tokenizer = None

+     # Load tokenizer
+     if tokenizer_model_path is not None and mamba2_model_type == "codestral":
+         tokenizer_class = LlamaTokenizerFast
+         tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
+     elif mamba2_model_type == "mamba_ssm":
+         tokenizer = Mamba2TokenizerFast.from_pretrained("state-spaces/mamba-130m-hf")

- def load_state_dict_from_torch(mamba2_checkpoint_path: str) -> Dict[str, torch.Tensor]:
-     return torch.load(path.join(mamba2_checkpoint_path, "pytorch_model.bin"), map_location="cpu")
+     # Save tokenizer
+     if tokenizer is not None:
+         tokenizer.save_pretrained(output_dir)
+
+
+ _MAMBA2_MODELS_DICT = {
+     "codestral": {
+         "hidden_size": "dim",
+         "num_hidden_layers": "n_layers",
+         "n_groups": "n_groups",
+         "residual_in_fp32": "residual_in_fp32",
+         "tie_word_embeddings": "tie_embeddings",
+         "norm_before_gate": False,
+         "vocab_size": "vocab_size",
+         "pad_vocab_size_multiple": "pad_vocab_size_multiple",
+         "bos_token_id": 0,
+         "pad_token_id": 1,
+         "eos_token_id": 2,
+         "config_name": "params.json",
+         "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
+         "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
+     },
+     "mamba_ssm": {
+         "hidden_size": "d_model",
+         "num_hidden_layers": "n_layer",
+         "n_groups": "ngroups",
+         "residual_in_fp32": "residual_in_fp32",
+         "tie_word_embeddings": "tie_embeddings",
+         "norm_before_gate": False,
+         "vocab_size": "vocab_size",
+         "pad_vocab_size_multiple": "pad_vocab_size_multiple",
+         "bos_token_id": 0,
+         "pad_token_id": 0,
+         "eos_token_id": 0,
+         "config_name": "config.json",
+         "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
+         "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
+     },
+ }
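Design note on the added dict: each entry binds its checkpoint and tokenizer loaders with `functools.partial`, so the converter below can dispatch purely by model type instead of branching on config keys. A minimal dispatch sketch (not part of the patch; the directory paths are placeholders):

# Sketch only, assuming a local mamba_ssm checkpoint directory at a placeholder path.
model_dict = _MAMBA2_MODELS_DICT["mamba_ssm"]
state_dict = model_dict["load_state_dict"](mamba2_checkpoint_path="/path/to/ckpt_dir")
model_dict["load_and_save_tokenizer"](output_dir="/path/to/output_dir", tokenizer_model_path=None)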


def convert_mamba2_checkpoint_file_to_huggingface_model_file(
-     mamba2_checkpoint_path: str, precision: str, output_dir: str, tokenizer_model_path: Optional[str] = None
+     mamba2_checkpoint_path: str,
+     mamba2_model_type: str,
+     precision: str,
+     output_dir: str,
+     tokenizer_model_path: Optional[str] = None,
) -> None:
+     mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]
+
    # Load and save config based on name
-     config_path = mamba2_checkpoint_path
-     config_path = (
-         path.join(config_path, "params.json")
-         if path.isfile(path.join(config_path, "params.json"))
-         else path.join(config_path, "config.json")
-     )
+     config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
    with open(config_path, "r", encoding="utf-8") as json_file:
        config = json.load(json_file)
-     hf_config, is_not_codestral = convert_ssm_config_to_hf_config(config)
+     hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
    hf_config.save_pretrained(output_dir)

-     # Load state dict of the original model
-     state_dict_load_function = load_state_dict_from_torch if is_not_codestral else load_state_dict_from_safetensors
-     original_state_dict = state_dict_load_function(mamba2_checkpoint_path)
-
-     # Load and transfer to hf model
+     # Load state dict of the original model and transfer to hf model
+     original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
    hf_model = Mamba2ForCausalLM(hf_config)
    hf_model.load_state_dict(original_state_dict)

@@ -131,12 +151,7 @@ def convert_mamba2_checkpoint_file_to_huggingface_model_file(
    save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})

    # Load and save tokenizer
-     if tokenizer_model_path is not None and not is_not_codestral:
-         tokenizer_class = LlamaTokenizerFast
-         tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
-     else:
-         tokenizer = Mamba2TokenizerFast.from_pretrained("state-spaces/mamba-130m-hf")
-     tokenizer.save_pretrained(output_dir)
+     mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)


if __name__ == "__main__":
@@ -148,13 +163,24 @@ def convert_mamba2_checkpoint_file_to_huggingface_model_file(
        required=True,
        help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
    )
+     parser.add_argument(
+         "-m",
+         "--mamba2_model_type",
+         type=str,
+         default="mamba_ssm",
+         required=True,
+         choices=("codestral", "mamba_ssm"),
+         help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
+     )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
-         default="bf16",
+         default="fp16",
        required=True,
-         choices=["fp32", "fp16", "bf16"],
+         choices=("fp32", "fp16", "bf16"),
        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
    )
    parser.add_argument(
@@ -171,5 +197,9 @@ def convert_mamba2_checkpoint_file_to_huggingface_model_file(
    args = parser.parse_args()

    convert_mamba2_checkpoint_file_to_huggingface_model_file(
-         args.mamba2_checkpoint_file, args.precision, args.output_dir, args.tokenizer_model_path
+         args.mamba2_checkpoint_directory,
+         args.mamba2_model_type,
+         args.precision,
+         args.output_dir,
+         args.tokenizer_model_path,
    )
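For reference, a usage sketch of the converter after this change; the paths, the `bf16` choice, and the script file name in the comment are placeholders rather than values taken from the patch:

# Usage sketch only; paths and the script name are hypothetical.
from transformers import Mamba2ForCausalLM

# Roughly equivalent CLI call:
#   python convert_mamba2_ssm_checkpoint_to_pytorch.py -i /path/to/mamba2_ckpt_dir -m mamba_ssm -p bf16 -o /path/to/output_dir
convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="/path/to/mamba2_ckpt_dir",  # directory holding config.json and pytorch_model.bin
    mamba2_model_type="mamba_ssm",
    precision="bf16",
    output_dir="/path/to/output_dir",
)

# The output directory can then be loaded through the regular Transformers API.
hf_model = Mamba2ForCausalLM.from_pretrained("/path/to/output_dir")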