Mamba2 conversion script for original models #32580

Merged: 17 commits, Aug 29, 2024
Changes from 14 commits
4 changes: 4 additions & 0 deletions docs/source/en/model_doc/mamba2.md
@@ -95,6 +95,10 @@ trainer.train()

[[autodoc]] Mamba2Config

## Mamba2TokenizerFast

[[autodoc]] Mamba2TokenizerFast

## Mamba2Model

[[autodoc]] Mamba2Model
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1057,6 +1057,7 @@
_import_structure["models.llama"].append("LlamaTokenizerFast")
_import_structure["models.longformer"].append("LongformerTokenizerFast")
_import_structure["models.lxmert"].append("LxmertTokenizerFast")
_import_structure["models.mamba2"].append("Mamba2TokenizerFast")
_import_structure["models.markuplm"].append("MarkupLMTokenizerFast")
_import_structure["models.mbart"].append("MBartTokenizerFast")
_import_structure["models.mbart50"].append("MBart50TokenizerFast")
@@ -5840,6 +5841,7 @@
from .models.llama import LlamaTokenizerFast
from .models.longformer import LongformerTokenizerFast
from .models.lxmert import LxmertTokenizerFast
from .models.mamba2 import Mamba2TokenizerFast
from .models.markuplm import MarkupLMTokenizerFast
from .models.mbart import MBartTokenizerFast
from .models.mbart50 import MBart50TokenizerFast
2 changes: 1 addition & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -271,7 +271,7 @@
("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("mamba2", (None, "Mamba2TokenizerFast" if is_tokenizers_available() else None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
(
"mbart",
19 changes: 17 additions & 2 deletions src/transformers/models/mamba2/__init__.py
@@ -14,17 +14,24 @@

from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_tokenizers_available, is_torch_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)


_import_structure = {
"configuration_mamba2": ["Mamba2Config", "Mamba2OnnxConfig"],
}

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["tokenization_mamba2_fast"] = ["Mamba2TokenizerFast"]

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -41,6 +48,14 @@
if TYPE_CHECKING:
from .configuration_mamba2 import Mamba2Config, Mamba2OnnxConfig

try:
if not is_tokenizers_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .tokenization_mamba2_fast import Mamba2TokenizerFast

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -15,55 +15,188 @@
"""This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""

import argparse
import json
from functools import partial
from os import path
from typing import Dict, Optional

import torch
from safetensors import safe_open
from safetensors.torch import save_model

from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM, Mamba2TokenizerFast


def convert_mamba2_checkpoint_file_to_huggingface_model_file(
mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str
) -> None:
hf_config = Mamba2Config()
hf_model = Mamba2ForCausalLM(hf_config)
def load_state_dict_from_safetensors(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
# Load weights and config from paths
original_state_dict = {}
with safe_open(mamba2_checkpoint_path, framework="pt") as f:
with safe_open(path.join(mamba2_checkpoint_path, ckpt_name), framework="pt") as f:
for k in f.keys():
newk = k.removeprefix("model.")
original_state_dict[newk] = f.get_tensor(k).clone()
return original_state_dict


def load_state_dict_from_torch(mamba2_checkpoint_path: str, ckpt_name: str) -> Dict[str, torch.Tensor]:
return torch.load(path.join(mamba2_checkpoint_path, ckpt_name), map_location="cpu")


def convert_ssm_config_to_hf_config(config_ssm: Dict, mamba2_model_dict: Dict) -> Mamba2Config:
"""Convert a Mamba2Config from mamba_ssm to a Mamba2Config from here."""
hf_config = Mamba2Config()

# Switch to a different dict depending on model type
config_dict = mamba2_model_dict

# Set important values from config and recalculate other resulting entries
hf_config.hidden_size = config_ssm[config_dict["hidden_size"]]
hf_config.num_heads = (hf_config.hidden_size * hf_config.expand) // hf_config.head_dim
hf_config.num_hidden_layers = config_ssm[config_dict["num_hidden_layers"]]
hf_config.n_groups = config_ssm.get(config_dict["n_groups"], 1)
hf_config.residual_in_fp32 = config_ssm[config_dict["residual_in_fp32"]]
hf_config.tie_word_embeddings = config_ssm[config_dict["tie_word_embeddings"]]
hf_config.bos_token_id = config_dict["bos_token_id"]
hf_config.pad_token_id = config_dict["pad_token_id"]
hf_config.eos_token_id = config_dict["eos_token_id"]
Collaborator:
We could also just init with the correct values from the config_dict here. Most of the names match, so **config_dict might be simpler for the entries that map one to one.

Contributor Author:
Hmm, I'm not so sure I'd like that, since it would become unclear which attributes get overridden; for example, n_groups is overridden for Codestral but not for the original paper models.
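
A rough sketch of what the suggested `**config_dict` initialization could look like; this is only an illustration of the review suggestion, not the code the PR ships. `config_ssm` and `config_dict` stand in for the locals of `convert_ssm_config_to_hf_config`, and derived or divergent fields (num_heads, the padded vocab_size, n_groups for the paper models) would still need explicit handling afterwards:

```python
from typing import Dict

from transformers import Mamba2Config


def init_hf_config_from_matches(config_ssm: Dict, config_dict: Dict) -> Mamba2Config:
    # Keep only entries whose mapped name actually exists in the original config;
    # int defaults (token ids) and loader callables in the mapping are skipped.
    matched = {
        hf_name: config_ssm[ssm_name]
        for hf_name, ssm_name in config_dict.items()
        if isinstance(ssm_name, str) and ssm_name in config_ssm
    }
    return Mamba2Config(**matched)
```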


# Pad the vocab size, usually to a multiple of 16, though 32 is also common in some models
vocab_size = config_ssm[config_dict["vocab_size"]]
pad_vocab_size_multiple = config_ssm[config_dict["pad_vocab_size_multiple"]]
if (vocab_size % pad_vocab_size_multiple) != 0:
vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
hf_config.vocab_size = vocab_size

return hf_config


def load_and_save_tokenizer(
mamba2_model_type: str,
output_dir: str,
tokenizer_model_path: Optional[str] = None,
) -> None:
tokenizer = None

# Load tokenizer
if tokenizer_model_path is not None and mamba2_model_type == "codestral":
tokenizer_class = LlamaTokenizerFast
tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
elif mamba2_model_type == "mamba_ssm":
tokenizer = Mamba2TokenizerFast.from_pretrained("state-spaces/mamba-130m-hf")

# Save tokenizer
if tokenizer is not None:
tokenizer.save_pretrained(output_dir)


_MAMBA2_MODELS_DICT = {
"codestral": {
"hidden_size": "dim",
"num_hidden_layers": "n_layers",
"n_groups": "n_groups",
"residual_in_fp32": "residual_in_fp32",
"tie_word_embeddings": "tie_embeddings",
"vocab_size": "vocab_size",
"pad_vocab_size_multiple": "pad_vocab_size_multiple",
"bos_token_id": 0,
"pad_token_id": 1,
"eos_token_id": 2,
"config_name": "params.json",
"load_state_dict": partial(load_state_dict_from_safetensors, ckpt_name="consolidated.safetensors"),
"load_and_save_tokenizer": partial(load_and_save_tokenizer, "codestral"),
},
"mamba_ssm": {
"hidden_size": "d_model",
"num_hidden_layers": "n_layer",
"n_groups": "ngroups",
"residual_in_fp32": "residual_in_fp32",
"tie_word_embeddings": "tie_embeddings",
"vocab_size": "vocab_size",
"pad_vocab_size_multiple": "pad_vocab_size_multiple",
"bos_token_id": 0,
"pad_token_id": 0,
"eos_token_id": 0,
"config_name": "config.json",
"load_state_dict": partial(load_state_dict_from_torch, ckpt_name="pytorch_model.bin"),
"load_and_save_tokenizer": partial(load_and_save_tokenizer, "mamba_ssm"),
},
Collaborator:
For any parameters that have the same name in both configs, IMO it's pointless to keep them in the mapping (like vocab_size, residual_in_fp32, etc.).

Contributor Author:
True that! I will change that later.

Contributor Author:
Simplified it now, and removed residual_in_fp32 entirely since its default value is shared across all of them.
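
Purely as an illustration of the trimmed mapping hinted at in this thread (the exact keys kept in the follow-up commit may differ), entries whose names already match one to one, such as vocab_size and pad_vocab_size_multiple, would be read directly from the original config instead:

```python
# Hypothetical trimmed mapping: only keys whose names differ between the original
# configs and Mamba2Config remain; shared defaults like residual_in_fp32 are dropped.
_TRIMMED_MAMBA2_MODELS_DICT = {
    "codestral": {
        "hidden_size": "dim",
        "num_hidden_layers": "n_layers",
        "tie_word_embeddings": "tie_embeddings",
    },
    "mamba_ssm": {
        "hidden_size": "d_model",
        "num_hidden_layers": "n_layer",
        "n_groups": "ngroups",
        "tie_word_embeddings": "tie_embeddings",
    },
}
```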

}


def convert_mamba2_checkpoint_file_to_huggingface_model_file(
mamba2_checkpoint_path: str,
mamba2_model_type: str,
precision: str,
output_dir: str,
tokenizer_model_path: Optional[str] = None,
) -> None:
mamba2_model_dict = _MAMBA2_MODELS_DICT[mamba2_model_type]

# Load and save config based on name
config_path = path.join(mamba2_checkpoint_path, mamba2_model_dict["config_name"])
with open(config_path, "r", encoding="utf-8") as json_file:
config = json.load(json_file)
hf_config = convert_ssm_config_to_hf_config(config_ssm=config, mamba2_model_dict=mamba2_model_dict)
hf_config.save_pretrained(output_dir)

# Load state dict of the original model and transfer to hf model
original_state_dict = mamba2_model_dict["load_state_dict"](mamba2_checkpoint_path=mamba2_checkpoint_path)
hf_model = Mamba2ForCausalLM(hf_config)
hf_model.load_state_dict(original_state_dict)

# Save new model to pytorch_dump_path
hf_model.to(torch.bfloat16).save_pretrained(output_dir)
tokenizer_class = LlamaTokenizerFast
tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
tokenizer.save_pretrained(output_dir)
dtype = torch.float32 if precision == "fp32" else (torch.bfloat16 if precision == "bf16" else torch.float16)
save_model(hf_model.to(dtype), path.join(output_dir, "model.safetensors"), metadata={"format": "pt"})

# Load and save tokenizer
mamba2_model_dict["load_and_save_tokenizer"](output_dir=output_dir, tokenizer_model_path=tokenizer_model_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--mamba2_checkpoint_file",
"--mamba2_checkpoint_directory",
type=str,
required=True,
help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.",
help="Path to a directory containing the `pytorch_model.bin` or `.safetensors` mamba2_ssm checkpoint file to be converted.",
)
parser.add_argument(
"-c",
"--tokenizer_model_path",
"-m",
"--mamba2_model_type",
type=str,
default="mamba_ssm",
const="mamba_ssm",
required=True,
choices=("codestral", "mamba_ssm"),
help="The model type the conversion will be performed on. Can choose from either `codestral` or `mamba_ssm`.",
)
Comment on lines 154 to +163

Contributor Author:
User-dependent choice at conversion time. For now it defaults to mamba_ssm, since Codestral already has a repo that can be used directly, so I'd expect people to use this script locally mostly for the original paper models.

parser.add_argument(
"-p",
"--precision",
type=str,
default="fp16",
const="fp16",
required=True,
help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.",
choices=("fp32", "fp16", "bf16"),
help="The precision the model will be saved in. Select from fp32, fp16 or bf16.",
)
parser.add_argument(
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
)
parser.add_argument(
"-t",
"--tokenizer_model_path",
type=str,
default=None,
required=False,
help="Path to a `codestral` tokenizer file.",
)
args = parser.parse_args()

convert_mamba2_checkpoint_file_to_huggingface_model_file(
args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir
args.mamba2_checkpoint_directory,
args.mamba2_model_type,
args.precision,
args.output_dir,
args.tokenizer_model_path,
)
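
A hedged usage sketch of the new entry point, called directly from Python rather than via the CLI; the module path is an assumption (the filename is not shown in this excerpt) and all paths are placeholders:

```python
# Assumed module path; adjust to wherever the conversion script lives in your checkout.
from transformers.models.mamba2.convert_mamba2_ssm_checkpoint_to_pytorch import (
    convert_mamba2_checkpoint_file_to_huggingface_model_file,
)

convert_mamba2_checkpoint_file_to_huggingface_model_file(
    mamba2_checkpoint_path="./mamba2-2.7b",  # directory containing config.json and pytorch_model.bin
    mamba2_model_type="mamba_ssm",           # or "codestral", together with tokenizer_model_path
    precision="bf16",                        # one of "fp32", "fp16", "bf16"
    output_dir="./mamba2-2.7b-hf",
    tokenizer_model_path=None,               # only needed for the codestral variant
)
```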
36 changes: 36 additions & 0 deletions src/transformers/models/mamba2/tokenization_mamba2_fast.py
Collaborator:
I don't think we need this. Overall it should still be up to the user, and the tokenizer should be set to left padding when converting.

Contributor Author:
I originally wanted to use the GPT-NeoX tokenizer from Mamba 1 and override the padding side, but I wasn't aware of how to do it: after saving and reloading, it reverts to the non-overridden padding side (see #32580 (comment)).

That's how this separate tokenizer came to be. Is there an easy way to override the padding side, so that after loading the tokenizer and overriding it, the override persists across a reload? 👀

Collaborator:
Given this:

self.padding_side = kwargs.pop("padding_side", self.padding_side)

I think we should update GPTNeoXTokenizerFast to take padding_side as input, plus a test making sure loading/saving does not overwrite it!

Contributor Author:
Oh wow, it's so simple 😆

Removed the Mamba2 tokenizer now and added padding_side as a kwarg to the GPT-NeoX tokenizer. I tested locally and the padding side indeed persists. There is no test file for the GPT-NeoX tokenizer; should I create one, or is there somewhere else this test should go?

Contributor Author:
Removed the kwarg again from the tokenizer file since it's passed through even without it. Testing still remains 👀
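
A minimal sketch of the behaviour discussed above, assuming padding_side is forwarded as described: pass it when loading the GPT-NeoX tokenizer and check that it survives a save/reload round trip. The checkpoint name is the one used in the conversion script; the output path is a placeholder:

```python
from transformers import GPTNeoXTokenizerFast

tokenizer = GPTNeoXTokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="left")
tokenizer.save_pretrained("/tmp/mamba2-tokenizer")

# padding_side is stored in tokenizer_config.json, so the override persists on reload.
reloaded = GPTNeoXTokenizerFast.from_pretrained("/tmp/mamba2-tokenizer")
assert reloaded.padding_side == "left"
```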

@@ -0,0 +1,36 @@
# coding=utf-8
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Mamba2."""

from ...utils import logging
from ..gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast


logger = logging.get_logger(__name__)


class Mamba2TokenizerFast(GPTNeoXTokenizerFast):
"""
Utility class to overwrite the padding side of a GPTNeoXTokenizerFast tokenizer.
"""

padding_side = "left"

def __init__(self, *args, **kwargs):
# Silently remove padding side on init
kwargs.pop("padding_side", None)

# Otherwise we take over all other parameters
super().__init__(*args, **kwargs, padding_side="left")
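
A quick usage sketch of the class above as it stands in this revision (the review thread above later removes it): whatever padding_side is passed in, instances come back left-padded. The checkpoint name is the one already used in the conversion script:

```python
from transformers import Mamba2TokenizerFast

tok = Mamba2TokenizerFast.from_pretrained("state-spaces/mamba-130m-hf", padding_side="right")
print(tok.padding_side)  # prints "left": the kwarg is silently dropped and forced to left
```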
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_tokenizers_objects.py
@@ -275,6 +275,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class Mamba2TokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["tokenizers"])


class MarkupLMTokenizerFast(metaclass=DummyObject):
_backends = ["tokenizers"]
