Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VLM: special multimodal Tokenizer #34461

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
19 changes: 19 additions & 0 deletions docs/source/en/main_classes/tokenizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,25 @@ token space (e.g., getting the index of the token comprising a given character o
to a given token).


# Multimodal Tokenizer

In addition, any tokenizer can be a "multimodal" tokenizer, which means that the tokenizer will hold all relevant special tokens
as part of tokenizer attributes for easier access. For example, if the tokenizer is loaded from a vision-language model like LLaVA, you will
be able to access `tokenizer.image_token_id` to obtain the special image token used as a placeholder.

To enable extra special tokens for any type of tokenizer, you have to add the following lines and save the tokenizer. Extra special tokens do not
have to be modality related and can be anything that the model often needs access to. In the below code, the tokenizer at `output_dir` will have direct access
to three more special tokens.

```python
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.extra_special_tokens = ["image_token", "boi_token", "eoi_token"]
tokenizer.save_pretrained(output_dir)

vision_tokenizer = AutoTokenizer.from_pretrained(output_dir)
vision_tokenizer.image_token = "IMAGE"
```

## PreTrainedTokenizer

[[autodoc]] PreTrainedTokenizer
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/blip_2/processing_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ class Blip2Processor(ProcessorMixin):
def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
    """Build a BLIP-2 processor from an image processor and a tokenizer.

    Args:
        image_processor: the image processor used for pixel inputs.
        tokenizer: the text tokenizer; token-type ids are disabled because the
            model does not use them.
        num_query_tokens: number of query tokens inserted per image
            (presumably the Q-Former query count — TODO confirm against model config).
    """
    tokenizer.return_token_type_ids = False
    self.current_processor = image_processor
    # Prefer the special token already registered on a "multimodal" tokenizer;
    # otherwise register our own "<image>" placeholder token.
    if not hasattr(tokenizer, "image_token"):
        self.image_token = AddedToken("<image>", normalized=False, special=True)
        tokenizer.add_tokens([self.image_token], special_tokens=True)
    else:
        self.image_token = tokenizer.image_token
    self.num_query_tokens = num_query_tokens

    super().__init__(image_processor, tokenizer)
Expand Down
9 changes: 6 additions & 3 deletions src/transformers/models/chameleon/processing_chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ class ChameleonProcessor(ProcessorMixin):

def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
    """Build a Chameleon processor.

    Args:
        image_processor: the image processor used for pixel inputs.
        tokenizer: the text tokenizer; if it is a multimodal tokenizer its
            registered special tokens take precedence over the defaults below.
        image_seq_length: number of placeholder tokens emitted per image.
        image_token: fallback placeholder token when the tokenizer does not
            define `image_token`.
    """
    self.image_seq_length = image_seq_length
    # Multimodal tokenizers carry their own special tokens; fall back to the
    # historical hardcoded values otherwise.
    self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
    self.image_start_token = (
        tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
    )  # fixed tokens for start and end, so can hardcode
    self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"

    super().__init__(image_processor, tokenizer)

def __call__(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gemma/tokenization_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __getstate__(self):
return state

def __setstate__(self, d):
    """Restore pickled state and rebuild the SentencePiece processor.

    Uses `update()` rather than replacing `__dict__` wholesale so that
    attributes set by the class before/around unpickling (e.g. extra special
    tokens) are preserved instead of being discarded.
    """
    self.__dict__.update(d)
    # sp_model is not picklable; recreate it from the serialized proto.
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

Expand Down
6 changes: 5 additions & 1 deletion src/transformers/models/idefics/processing_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_u

super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
self.image_token_id = (
tokenizer.image_token_id
if hasattr(tokenizer, "image_token")
else tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
)

self.default_image_dims = (
self.image_processor.image_num_channels,
Expand Down
17 changes: 10 additions & 7 deletions src/transformers/models/idefics2/processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,19 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, cha
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")

self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
self.image_token = AddedToken("<image>", normalized=False, special=True)
if not hasattr(tokenizer, "image_token"):
self.fake_image_token = AddedToken("<fake_token_around_image>", normalized=False, special=True)
self.image_token = AddedToken("<image>", normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [self.fake_image_token, self.image_token]}
tokenizer.add_special_tokens(tokens_to_add)
else:
self.fake_image_token = tokenizer.image_boundary_token
self.image_token = tokenizer.image_token

self.end_of_utterance_token = AddedToken("<end_of_utterance>", normalized=False, special=True)
tokenizer.add_special_tokens({"additional_special_tokens": [self.end_of_utterance_token]})
self.image_seq_len = image_seq_len

tokens_to_add = {
"additional_special_tokens": [self.fake_image_token, self.image_token, self.end_of_utterance_token]
}
tokenizer.add_special_tokens(tokens_to_add)

super().__init__(image_processor, tokenizer, chat_template=chat_template)

def _extract_images_from_prompts(self, prompts):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,11 @@ class InstructBlipProcessor(ProcessorMixin):
qformer_tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
    """Build an InstructBLIP processor.

    Args:
        image_processor: the image processor used for pixel inputs.
        tokenizer: the language-model tokenizer.
        qformer_tokenizer: tokenizer for the Q-Former text input.
        num_query_tokens: number of query tokens inserted per image.
    """
    # Prefer the special token already registered on a "multimodal" tokenizer;
    # otherwise register our own "<image>" placeholder token.
    if not hasattr(tokenizer, "image_token"):
        self.image_token = AddedToken("<image>", normalized=False, special=True)
        tokenizer.add_tokens([self.image_token], special_tokens=True)
    else:
        self.image_token = tokenizer.image_token
    self.num_query_tokens = num_query_tokens
    super().__init__(image_processor, tokenizer, qformer_tokenizer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ class InstructBlipVideoProcessor(ProcessorMixin):
qformer_tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor, tokenizer, qformer_tokenizer, num_query_tokens=None, **kwargs):
    """Build an InstructBLIP-Video processor.

    Args:
        image_processor: the image processor used for video frames.
        tokenizer: the language-model tokenizer.
        qformer_tokenizer: tokenizer for the Q-Former text input.
        num_query_tokens: number of query tokens inserted per video.
    """
    # Prefer the special token already registered on a "multimodal" tokenizer;
    # otherwise register our own "<video>" placeholder token.
    if not hasattr(tokenizer, "video_token"):
        self.video_token = AddedToken("<video>", normalized=False, special=True)
        tokenizer.add_tokens([self.video_token], special_tokens=True)
    else:
        self.video_token = tokenizer.video_token
    self.num_query_tokens = num_query_tokens
    super().__init__(image_processor, tokenizer, qformer_tokenizer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __getstate__(self):
return state

def __setstate__(self, d):
self.__dict__ = d
self.__dict__.update(d)

# for backward compatibility
if not hasattr(self, "sp_model_kwargs"):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llama/tokenization_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def __getstate__(self):
return state

def __setstate__(self, d):
    """Restore pickled state and rebuild the SentencePiece processor.

    Uses `update()` rather than replacing `__dict__` wholesale so that
    attributes set by the class before/around unpickling (e.g. extra special
    tokens) are preserved instead of being discarded.
    """
    self.__dict__.update(d)
    # sp_model is not picklable; recreate it from the serialized proto.
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/llava/processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
):
self.patch_size = patch_size
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = image_token
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
):
self.patch_size = patch_size
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = image_token
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def __init__(
):
self.patch_size = patch_size
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = image_token
self.video_token = video_token
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
super().__init__(video_processor, image_processor, tokenizer, chat_template=chat_template)

def __call__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def __init__(
):
self.num_image_tokens = num_image_tokens
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = image_token
self.video_token = video_token
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)

def __call__(
Expand Down
9 changes: 7 additions & 2 deletions src/transformers/models/mllama/processing_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,13 @@ class MllamaProcessor(ProcessorMixin):
tokenizer_class = "PreTrainedTokenizerFast"

def __init__(self, image_processor, tokenizer):
self.image_token = "<|image|>"
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
if not hasattr(tokenizer, "image_token"):
self.image_token = "<|image|>"
self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
else:
self.image_token = tokenizer.image_token
self.image_token_id = tokenizer.image_token_id

self.python_token = "<|python_tag|>"
self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
self.bos_token = tokenizer.bos_token
Expand Down
12 changes: 8 additions & 4 deletions src/transformers/models/paligemma/processing_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,15 @@ def __init__(

self.image_seq_length = image_processor.image_seq_length

image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
if not hasattr(tokenizer, "image_token"):
image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
tokens_to_add = {"additional_special_tokens": [image_token]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
else:
self.image_token_id = tokenizer.image_token_id

tokenizer.add_tokens(EXTRA_TOKENS)
self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

Expand Down
14 changes: 8 additions & 6 deletions src/transformers/models/qwen2_vl/processing_qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class Qwen2VLProcessor(ProcessorMixin):
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
    """Build a Qwen2-VL processor.

    Image/video placeholder tokens are taken from a multimodal tokenizer when
    available, otherwise the model's historical defaults are used.
    """
    self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
    self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
    super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
Expand Down Expand Up @@ -132,23 +134,23 @@ def __call__(
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while "<|image_pad|>" in text[i]:
while self.image_token in text[i]:
text[i] = text[i].replace(
"<|image_pad|>", "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
self.image_token, "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), 1
)
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
text[i] = text[i].replace("<|placeholder|>", self.image_token)

if video_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while "<|video_pad|>" in text[i]:
while self.video_token in text[i]:
text[i] = text[i].replace(
"<|video_pad|>", "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
self.video_token, "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), 1
)
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|video_pad|>")
text[i] = text[i].replace("<|placeholder|>", self.video_token)

text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/udop/tokenization_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def __getstate__(self):
return state

def __setstate__(self, d):
    """Restore pickled state and rebuild the SentencePiece processor.

    Uses `update()` rather than replacing `__dict__` wholesale so that
    attributes set by the class before/around unpickling (e.g. extra special
    tokens) are preserved instead of being discarded.
    """
    self.__dict__.update(d)
    # sp_model is not picklable; reload it from the vocab file on disk.
    self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
    self.sp_model.Load(self.vocab_file)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/video_llava/processing_video_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __init__(
):
self.patch_size = patch_size
self.vision_feature_select_strategy = vision_feature_select_strategy
self.image_token = image_token
self.video_token = video_token
self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
self.video_token = tokenizer.video_token if hasattr(tokenizer, "video_token") else video_token
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/tokenization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_to
token_index = current_vocab[token.content]

if token.special and str(token) not in self.all_special_tokens:
self._additional_special_tokens.append(token)
self._special_tokens_map["additional_special_tokens"].append(token)
# the setter automatically updates the reverse map
self._added_tokens_decoder[token_index] = token
self._added_tokens_encoder[token.content] = token_index
Expand Down
Loading
Loading