Commit

Merge pull request #2 from huggingface/add_mllama_processor
Add processing code
qubvel authored Jul 5, 2024
2 parents 462bf3c + a08d2f0 commit 5bef63c
Showing 7 changed files with 946 additions and 59 deletions.
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -1146,6 +1146,7 @@
_import_structure["models.llava_next"].append("LlavaNextImageProcessor")
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
_import_structure["models.mllama"].extend(["MllamaImageProcessor"])
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
_import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
_import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -98,6 +98,7 @@
("mask2former", ("Mask2FormerImageProcessor",)),
("maskformer", ("MaskFormerImageProcessor",)),
("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
("mllama", ("MllamaImageProcessor",)),
("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
("mobilenet_v2", ("MobileNetV2ImageProcessor",)),
("mobilevit", ("MobileViTImageProcessor",)),
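For reference, registering "mllama" in this auto mapping is what lets AutoImageProcessor resolve the new class from a checkpoint's config. A minimal sketch, assuming a local checkpoint directory whose config declares model_type "mllama" (the path is hypothetical):

    from transformers import AutoImageProcessor

    # Resolves to MllamaImageProcessor via the ("mllama", ("MllamaImageProcessor",)) entry above.
    image_processor = AutoImageProcessor.from_pretrained("path/to/mllama-checkpoint")
    print(type(image_processor).__name__)  # MllamaImageProcessor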
25 changes: 24 additions & 1 deletion src/transformers/models/mllama/__init__.py
@@ -13,7 +13,12 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_vision_available,
)


_import_structure = {
@@ -33,6 +38,14 @@
"MllamaPreTrainedModel",
]

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_mllama"] = ["MllamaImageProcessor"]


if TYPE_CHECKING:
from .configuration_mllama import MllamaConfig
@@ -49,6 +62,16 @@
MllamaPreTrainedModel,
)

try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_mllama import (
MllamaImageProcessor,
)

else:
import sys

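For reference, the OptionalDependencyNotAvailable guard above keeps `import transformers` working even when Pillow is not installed; MllamaImageProcessor is only exported when the vision extras are present. A minimal sketch of the same guard pattern used in downstream code (not library code, and the fallback is an assumption for illustration):

    from transformers.utils import OptionalDependencyNotAvailable, is_vision_available

    try:
        if not is_vision_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        MllamaImageProcessor = None  # vision extras missing: leave the symbol unset
    else:
        from transformers import MllamaImageProcessor  # Pillow is available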
91 changes: 89 additions & 2 deletions src/transformers/models/mllama/convert_mllama_weights_to_hf.py
@@ -23,7 +23,7 @@

from transformers import LlamaConfig, LlamaForCausalLM, MllamaConfig, CLIPVisionConfig, LlamaTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import TikTokenConverter

from transformers import MllamaImageProcessor

try:
from transformers import LlamaTokenizerFast
@@ -374,6 +374,83 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
language_model.save_pretrained(model_path, safe_serialization=safe_serialization)
shutil.rmtree(tmp_model_path)

# TODO: update to new provided code: python + video tokens
class MllamaConverter(TikTokenConverter):
def __init__(self, vocab_file, num_reserved_special_tokens=256, **kwargs):
super().__init__(vocab_file, **kwargs)
tokenizer = self.converted()
chat_template = (
"{% set loop_messages = messages %}"
"{% for message in loop_messages %}"
"{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}"
"{% if loop.index0 == 0 %}"
"{% set content = bos_token + content %}"
"{% endif %}"
"{{ content }}"
"{% endfor %}"
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
"<|reserved_special_token_5|>",
"<|image|>",
] + [
f"<|reserved_special_token_{i}|>"
for i in range(6, num_reserved_special_tokens - 6)
]
tokenizer.add_special_tokens(special_tokens)

self.tokenizer = PreTrainedTokenizerFast(
tokenizer_object=tokenizer,
bos_token="<|begin_of_text|>",
eos_token="<|end_of_text|>",
chat_template=chat_template,
model_input_names=["input_ids", "attention_mask"],
)
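For reference, the chat template above wraps each message in Llama-3-style header tokens, prepends the BOS token to the first message, and always appends an assistant header. A sketch of the rendered prompt, assuming a converter instance built as in write_tokenizer() below (the message content is illustrative):

    messages = [{"role": "user", "content": "<|image|>Describe this image."}]
    prompt = converter.tokenizer.apply_chat_template(messages, tokenize=False)
    # prompt ==
    # "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    # "<|image|>Describe this image.<|eot_id|>"
    # "<|start_header_id|>assistant<|end_header_id|>\n\n"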


def write_tokenizer(tokenizer_path: str, save_dir: str):
    converter = MllamaConverter(
        tokenizer_path,
        # Llama 3 pre-tokenization split pattern (tiktoken-style regex).
        pattern=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",  # noqa: W605
    )
    tokenizer = converter.tokenizer
    tokenizer.save_pretrained(save_dir)
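For reference, the saved tokenizer can be reloaded with AutoTokenizer, and "<|image|>" encodes to a single id because it was registered via add_special_tokens above. A sketch assuming the "mllama" output directory written by the calls at the bottom of this script:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("mllama")
    ids = tokenizer("<|image|>Describe this image.")["input_ids"]
    # "<|image|>" maps to exactly one token id, not a sequence of sub-tokens.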


def write_image_processor(config_path: str, save_dir: str):
    params = read_json(config_path)

    # "vision_chunk_size" is the side length of each image tile (chunk);
    # "vision_max_num_chunks" caps how many tiles one image may be split into.
    tile_size = params["vision_chunk_size"]
    max_image_splits = params["vision_max_num_chunks"]

    image_processor = MllamaImageProcessor(
        do_resize=True,
        size={"height": tile_size, "width": tile_size},
        do_rescale=True,
        rescale_factor=1 / 255,
        do_normalize=True,
        image_mean=[0.48145466, 0.4578275, 0.40821073],  # CLIP normalization statistics
        image_std=[0.26862954, 0.26130258, 0.27577711],
        do_pad=True,
        do_image_splitting=True,
        max_image_splits=max_image_splits,
    )

    image_processor.save_pretrained(save_dir)


write_model(
model_path="/home/pablo/mllama_hf/test",
@@ -382,4 +459,14 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
model_size="70B",
llama_version=3,
vocab_size=128256,
)

write_tokenizer(
"weights/Meta-Llama-3.1-87B-Vision-Dummy-20240624190000/tokenizer.model",
"mllama",
)

write_image_processor(
"weights/Meta-Llama-3.1-87B-Vision-Dummy-20240624190000/params.json",
"mllama",
)
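For reference, once the script has written both directories, the image processor can be loaded back and applied to a PIL image. This is a sketch only; the exact output tensor names are defined in image_processing_mllama.py, which is not shown in this diff:

    from PIL import Image
    from transformers import AutoImageProcessor

    image_processor = AutoImageProcessor.from_pretrained("mllama")
    image = Image.new("RGB", (1024, 768))  # dummy image for illustration
    inputs = image_processor(images=image, return_tensors="pt")
    # With do_image_splitting=True, large images are padded and split into up to
    # max_image_splits tiles, each resized to vision_chunk_size x vision_chunk_size.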
