Commit

Mamba2 conversion script for original models (huggingface#32580)
* first attempt at allowing both conversions from codestral and from the original mamba ssm

* allow fp16, seems default for mamba2

* dtype fix

* simplify codestral check, don't overwrite pad/eos/bos when codestral

* change file -> directory

* use path join to be safe

* style

* apply code review
- add util mamba2 tokenizer (gptneox with left padding)
- add models dict

* fix copies

* add tokenizer to docs

* empty commit to check for weird err

* make conversion user dependent on model type, defaults for original paper models

* small comment nit

* remove norm_before_gate in conversion

* simplify model dict by using shared keys directly + remove unnecessary attributes

* fix tokenization: remove separate mamba2 tokenizer, add padding option as kwarg to gptneox one and reuse it for the conversion script

* simplify even further as we pass padding side via **kwargs already
vasqu authored and zucchini-nlp committed Aug 30, 2024
1 parent 32b79fb commit 8a04372
Showing 1 changed file with 141 additions and 17 deletions.
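As context for the diff below, here is a minimal sketch of how the new conversion entry point could be called directly from Python rather than via the CLI. The module path and every directory/file name are assumptions for illustration only, not taken from this commit.

# Sketch only: the import location and all paths below are hypothetical.
from transformers.models.mamba2.convert_mamba2_ssm_checkpoint_to_pytorch import (  # assumed module location
    convert_mamba2_checkpoint_file_to_huggingface_model_file,
)

# Original-paper mamba2_ssm checkpoint: a directory holding config.json + pytorch_model.bin
convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="checkpoints/mamba2-130m",  # hypothetical directory
    mamba2_model_type="mamba_ssm",
    precision="fp16",
    output_dir="converted/mamba2-130m",  # hypothetical directory
)

# Codestral-style checkpoint: params.json + consolidated.safetensors, plus its own tokenizer file
convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="checkpoints/mamba-codestral",  # hypothetical directory
    mamba2_model_type="codestral",
    precision="bf16",
    output_dir="converted/mamba-codestral",  # hypothetical directory
    tokenizer_model_path="checkpoints/mamba-codestral/tokenizer.model",  # hypothetical file
)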
@@ -15,55 +15,179 @@
"""This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""

import argparse
import json
from functools import partial
from os import path
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors.torch import save_model

from transformers import GPTNeoXTokenizerFast, LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM


def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
    # Load weights and config from paths
    original_state_dict = {}
    with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
        for k in f.keys():
            newk = k.removeprefix("model.")
            original_state_dict[newk] = f.get_tensor(k).clone()

    return original_state_dict


def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
    return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")


def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
    """Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
    hf_config = Mamba2Config()

    # Switch to a different dict depending on model type
    config_dict = mamba2_model_dict

    # Set important values from config and recalculate other resulting entries
    hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
    hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
    hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
    hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
    hf_config.tie_word_embeddings = config_ssm["tie_embeddings"]
    hf_config.bos_token_id = config_dict["bos_token_id"]
    hf_config.pad_token_id = config_dict["pad_token_id"]
    hf_config.eos_token_id = config_dict["eos_token_id"]

    # Pad the vocab size up to a multiple, mostly 16 but 32 is also common in other models
    vocab_size = config_ssm["vocab_size"]
    pad_vocab_size_multiple = config_ssm["pad_vocab_size_multiple"]
    if (vocab_size % pad_vocab_size_multiple) != 0:
        vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
    hf_config.vocab_size = vocab_size

    return hf_config


def load_and_save_tokenizer(
    mamba2_model_type: str,
    output_dir: str,
    tokenizer_model_path: Optional[str] = None,
) -> None:
    tokenizer = None

    # Load tokenizer
    if tokenizer_model_path is not None and mamba2_model_type == "codestral":
        tokenizer_class = LlamaTokenizerFast
        tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
    elif mamba2_model_type == "mamba_ssm":
        tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")

    # Save tokenizer
    if tokenizer is not None:
        tokenizer.save_pretrained(output_dir)


_MAMBA2_MODELS_DICT = {
    "codestral": {
        "hidden_size": "dim",
        "num_hidden_layers": "n_layers",
        "n_groups": "n_groups",
        "bos_token_id": 0,
        "pad_token_id": 1,
        "eos_token_id": 2,
        "config_name": "params.json",
        "load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
    },
    "mamba_ssm": {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layer",
        "n_groups": "ngroups",
        "bos_token_id": 0,
        "pad_token_id": 0,
        "eos_token_id": 0,
        "config_name": "config.json",
        "load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
        "load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
    },
}


def convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path: str,
    mamba2_model_type: str,
    precision: str,
    output_dir: str,
    tokenizer_model_path: Optional[str] = None,
) -> None:
    mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]

    # Load and save config based on name
    config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
    with open(config_path, "r", encoding="utf-8") as json_file:
        config = json.load(json_file)
    hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
    hf_config.save_pretrained(output_dir)

    # Load state dict of the original model and transfer to hf model
    original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
    hf_model = Mamba2ForCausalLM(hf_config)
    hf_model.load_state_dict(original_state_dict)

    # Save new model to pytorch_dump_path
    dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
    save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})

    # Load and save tokenizer
    mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--mamba2_checkpoint_directory",
        type=str,
        required=True,
        help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
    )
    parser.add_argument(
        "-m",
        "--mamba2_model_type",
        type=str,
        default="mamba_ssm",
        required=True,
        choices=("codestral", "mamba_ssm"),
        help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
    )
    parser.add_argument(
        "-p",
        "--precision",
        type=str,
        default="fp16",
        required=True,
        choices=("fp32", "fp16", "bf16"),
        help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
    )
    parser.add_argument(
        "-t",
        "--tokenizer_model_path",
        type=str,
        default=None,
        required=False,
        help="Path to a `codestral` tokenizer file.",
    )
    args = parser.parse_args()

    convert_mamba2_checkpoint_file_to_huggingface_model_file(
        args.mamba2_checkpoint_directory,
        args.mamba2_model_type,
        args.precision,
        args.output_dir,
        args.tokenizer_model_path,
    )
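After conversion, the output directory should be loadable with the standard transformers API. A minimal sketch, assuming a `mamba_ssm`-type conversion was written to a local converted/mamba2-130m directory (hypothetical path, not taken from this commit):

import torch

from transformers import AutoTokenizer, Mamba2ForCausalLM

model_dir = "converted/mamba2-130m"  # hypothetical path: whatever was passed as --output_dir

# For the mamba_ssm branch, the saved tokenizer is the GPT-NeoX one with padding_side="left"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = Mamba2ForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float16)

inputs = tokenizer("Hey how are you doing?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))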
