
Commit b873234

Authored by zucchini-nlp, with amyeroberts and NielsRogge
Llava: add default chat templates (huggingface#31691)
* add default chat templates
* Update src/transformers/models/llava/processing_llava.py (Co-authored-by: amyeroberts)
* Update src/transformers/models/llava_next/processing_llava_next.py (Co-authored-by: amyeroberts)
* more clear docstring and docs
* Update docs/source/en/model_doc/llava.md (Co-authored-by: NielsRogge)
* Update docs/source/en/model_doc/llava_next.md (Co-authored-by: NielsRogge)
* Update docs/source/en/model_doc/vipllava.md (Co-authored-by: NielsRogge)
* add tests
* remove default templates (see huggingface#31733)
* load chat template from another file
* Update docs/source/en/model_doc/llava_next.md (Co-authored-by: amyeroberts)
* revert some changes in docs
* forgot vipllava
* chat template file is not temporary hack
* warn if loading from processor
* not that file
* similarly modify `save_pretrained`
* Update tests/models/llava_next/test_processor_llava_next.py (Co-authored-by: amyeroberts)
* Update tests/models/vipllava/test_processor_vipllava.py (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/vipllava.md (Co-authored-by: amyeroberts)
* Update src/transformers/processing_utils.py (Co-authored-by: amyeroberts)
* Update src/transformers/processing_utils.py (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/vipllava.md (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/llava.md (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/llava.md (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/llava_next.md (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/llava_next.md (Co-authored-by: amyeroberts)
* Update src/transformers/processing_utils.py (Co-authored-by: amyeroberts)
* Update docs/source/en/model_doc/llava_next.md (Co-authored-by: amyeroberts)
* fix

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
1 parent 271fd8e commit b873234

File tree

8 files changed, +318 -19 lines changed

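Taken together, the changes below make `save_pretrained` write a processor's `chat_template` to a standalone `chat_template.json` next to `processor_config.json`, and make `from_pretrained` read it back onto the processor. A minimal round-trip sketch, assuming a checkpoint that ships a chat template (the checkpoint name and save path are illustrative, not part of this commit):

```python
from transformers import AutoProcessor

# Illustrative checkpoint and local path; any processor with a `chat_template` attribute behaves the same way
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# After this commit, save_pretrained also writes chat_template.json when a template is set
processor.save_pretrained("./llava-processor")

# from_pretrained picks the template back up from chat_template.json
reloaded = AutoProcessor.from_pretrained("./llava-processor")
print(reloaded.chat_template is not None)
```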

docs/source/en/model_doc/llava.md

+37 -1

@@ -40,7 +40,42 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/

 - Note the model has not been explicitly trained to process multiple images in the same prompt. Although this is technically possible, you may experience inaccurate results.

-- For better results, we recommend users to prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint:
+- For better results, we recommend using the processor's `apply_chat_template()` method to format your prompt correctly. For that, you need to construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with the keys "role" and "content". The "content" should be a list of dictionaries for the "text" and "image" modalities, as follows:
+
+```python
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What’s shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in more details."},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt; you still have to tokenize it and obtain pixel values for your images
+print(text_prompt)
+>>> "USER: <image>\nWhat’s shown in this image? ASSISTANT: This image shows a red stop sign.</s>USER: Describe the image in more details. ASSISTANT:"
+```
+
+- If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by each llava checkpoint:

 [llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) requires the following format:
 ```bash

@@ -64,6 +99,7 @@ For multiple turns conversation:
 "USER: <image>\n<prompt1> ASSISTANT: <answer1></s>USER: <prompt2> ASSISTANT: <answer2></s>USER: <prompt3> ASSISTANT:"
 ```
+
 ### Using Flash Attention 2

 Flash Attention 2 is an even faster, optimized version of the previous optimization; please refer to the [Flash Attention 2 section of the performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one).
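The note in the added example points out that `apply_chat_template` only builds the prompt string. A minimal end-to-end sketch of the remaining steps, assuming the same llava-1.5-7b checkpoint and an arbitrary example image URL (both are illustrative choices, not taken from this diff):

```python
import requests
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Any PIL image works; this URL is just an example
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

conversation = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What’s shown in this image?"}]},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

# The processor call tokenizes the prompt and computes pixel values for the image
inputs = processor(text=prompt, images=image, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0], skip_special_tokens=True))
```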

docs/source/en/model_doc/llava_next.md

+86 -10

@@ -46,26 +46,61 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/

 - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.

-- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. Below, we list the correct prompt formats to use for the text prompt "What is shown in this image?":
+- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` to format your prompts correctly. For that, you have to construct a conversation history; passing a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with the keys "role" and "content". The "content" should be a list of dictionaries for the "text" and "image" modalities. Below is an example of how to do that, followed by the list of formats accepted by each checkpoint.

-[llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) requires the following format:
+We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history of text and image. Each content field has to be a list of dicts, as follows:
+
+```python
+from transformers import LlavaNextProcessor
+
+processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What’s shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in more details."},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt; you still have to tokenize it and obtain pixel values for your images
+print(text_prompt)
+>>> "[INST] <image>\nWhat's shown in this image? [/INST] This image shows a red stop sign. [INST] Describe the image in more details. [/INST]"
+```
+
+- If you want to construct a chat prompt yourself, below is a list of possible formats.
+
+[llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) requires the following format:
 ```bash
 "[INST] <image>\nWhat is shown in this image? [/INST]"
 ```

 [llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf) and [llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) require the following format:
-
 ```bash
 "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
 ```

 [llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) requires the following format:
-
 ```bash
 "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
 ```

+
 ## Usage example

 ### Single image inference

@@ -86,8 +121,17 @@ model.to("cuda:0")
 # prepare image and text prompt, using the appropriate prompt template
 url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
 image = Image.open(requests.get(url, stream=True).raw)
-prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What is shown in this image?"},
+        ],
+    },
+]
+prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
 inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")

 # autoregressively complete prompt

@@ -120,15 +164,47 @@ image_cats = Image.open(requests.get(url, stream=True).raw)
 url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
 image_snowman = Image.open(requests.get(url, stream=True).raw)

-# Prepare a batched prompt, where the first one is a multi-turn conversation and the second is not
-prompt = [
-    "[INST] <image>\nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] <image>\nWhat about this image? How many cats do you see [/INST]",
-    "[INST] <image>\nWhat is shown in this image? [/INST]"
+# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not
+conversation_1 = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What is shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [
+            {"type": "text", "text": "There is a red stop sign in the image."},
+        ],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What about this image? How many cats do you see?"},
+        ],
+    },
 ]

+conversation_2 = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What is shown in this image?"},
+        ],
+    },
+]
+
+prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True)
+prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True)
+prompts = [prompt_1, prompt_2]
+
 # We can simply feed images in the order they have to be used in the text prompt
 # Each "<image>" token uses one image leaving the next for the subsequent "<image>" tokens
-inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)
+inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device)

 # Generate
 generate_ids = model.generate(**inputs, max_new_tokens=30)
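The batched example above stops at `generate`. A minimal sketch of decoding the batched output, assuming only the standard `batch_decode` API that the processor delegates to its tokenizer:

```python
# Decode both generations back to text, one string per prompt in the batch
outputs = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for text in outputs:
    print(text)
```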

docs/source/en/model_doc/vipllava.md

+41 -7

@@ -26,30 +26,64 @@ The abstract from the paper is the following:

 *While existing large vision-language multimodal models focus on whole image understanding, there is a prominent gap in achieving region-specific comprehension. Current approaches that use textual coordinates or spatial encodings often fail to provide a user-friendly interface for visual prompting. To address this challenge, we introduce a novel multimodal model capable of decoding arbitrary visual prompts. This allows users to intuitively mark images and interact with the model using natural cues like a "red bounding box" or "pointed arrow". Our simple design directly overlays visual markers onto the RGB image, eliminating the need for complex region encodings, yet achieves state-of-the-art performance on region-understanding tasks like Visual7W, PointQA, and Visual Commonsense Reasoning benchmark. Furthermore, we present ViP-Bench, a comprehensive benchmark to assess the capability of models in understanding visual prompts across multiple dimensions, enabling future research in this domain. Code, data, and model are publicly available.*

-Tips:
+The original code can be found [here](https://github.com/mu-cai/ViP-LLaVA).
+
+This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)
+
+
+## Usage tips:

 - The architecture is similar to the llava architecture, except that the multi-modal projector takes a set of concatenated vision hidden states and has an additional layernorm layer on that module.

 - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.

 - Note the model has not been explicitly trained to process multiple images in the same prompt. Although this is technically possible, you may experience inaccurate results.

-- For better results, we recommend users to prompt the model with the correct prompt format:
+- For better results, we recommend using the processor's `apply_chat_template()` method to format your prompt correctly. For that, you need to construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history for chat templates is a dictionary with the keys "role" and "content". The "content" should be a list of dictionaries for the "text" and "image" modalities, as follows:
+
+```python
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What’s shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in more details."},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt; you still have to tokenize it and obtain pixel values for your images
+print(text_prompt)
+>>> "###Human: <image>\nWhat’s shown in this image?###Assistant: This image shows a red stop sign.###Human: Describe the image in more details.###Assistant:"
+```

+- If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by VipLLaVa checkpoints:
 ```bash
 A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n<prompt>###Assistant:
 ```

 For multiple turns conversation:
-
 ```bash
 A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n<prompt1>###Assistant: <answer1>###Human: <prompt2>###Assistant:
 ```

-The original code can be found [here](https://github.com/mu-cai/ViP-LLaVA).
-
-This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)
-

 ## VipLlavaConfig

src/transformers/processing_utils.py

+54 -1

@@ -39,6 +39,7 @@
     TruncationStrategy,
 )
 from .utils import (
+    CHAT_TEMPLATE_NAME,
     PROCESSOR_NAME,
     PushToHubMixin,
     TensorType,

@@ -494,11 +495,21 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
                     del attribute.init_kwargs["auto_map"]

         # If we save using the predefined names, we can load using `from_pretrained`
+        # plus we save chat_template in its own file
         output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
+        output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)
+
+        processor_dict = self.to_dict()
+        chat_template = processor_dict.pop("chat_template", None)
+        if chat_template is not None:
+            chat_template_json_string = json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True) + "\n"
+            with open(output_chat_template_file, "w", encoding="utf-8") as writer:
+                writer.write(chat_template_json_string)
+            logger.info(f"chat template saved in {output_chat_template_file}")

         # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
         # `auto_map` is not specified.
-        if set(self.to_dict().keys()) != {"processor_class"}:
+        if set(processor_dict.keys()) != {"processor_class"}:
             self.to_json_file(output_processor_file)
             logger.info(f"processor saved in {output_processor_file}")

@@ -557,14 +568,21 @@ def get_processor_dict(
         is_local = os.path.isdir(pretrained_model_name_or_path)
         if os.path.isdir(pretrained_model_name_or_path):
             processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
+            chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json")
+
         if os.path.isfile(pretrained_model_name_or_path):
             resolved_processor_file = pretrained_model_name_or_path
+            # can't load a chat template when given a file as pretrained_model_name_or_path
+            resolved_chat_template_file = None
             is_local = True
         elif is_remote_url(pretrained_model_name_or_path):
             processor_file = pretrained_model_name_or_path
             resolved_processor_file = download_url(pretrained_model_name_or_path)
+            # can't load a chat template when given a file url as pretrained_model_name_or_path
+            resolved_chat_template_file = None
         else:
             processor_file = PROCESSOR_NAME
+            chat_template_file = CHAT_TEMPLATE_NAME
             try:
                 # Load from local folder or from cache or download from model Hub and cache
                 resolved_processor_file = cached_file(

@@ -581,6 +599,24 @@ def get_processor_dict(
                     subfolder=subfolder,
                     _raise_exceptions_for_missing_entries=False,
                 )
+
+                # Load the chat template from a separate json if it exists,
+                # because making it part of the processor config would break BC.
+                # Processors in older versions do not accept any kwargs.
+                resolved_chat_template_file = cached_file(
+                    pretrained_model_name_or_path,
+                    chat_template_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
             except EnvironmentError:
                 # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
                 # the original exception.

@@ -594,6 +630,14 @@ def get_processor_dict(
                     f" directory containing a {PROCESSOR_NAME} file"
                 )

+        # Add chat template as a kwarg before returning, because most models don't have a processor config
+        chat_template = None
+        if resolved_chat_template_file is not None:
+            with open(resolved_chat_template_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
+            chat_template = json.loads(text)["chat_template"]
+            kwargs["chat_template"] = chat_template
+
         # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
         # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
         # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)

@@ -617,6 +661,12 @@ def get_processor_dict(
         else:
             logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}")

+        if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
+            logger.warning_once(
+                "Chat templates should be in a 'chat_template.json' file but found key='chat_template' "
+                "in the processor's config. Make sure to move your template to its own file."
+            )
+
         if not is_local:
             if "auto_map" in processor_dict:
                 processor_dict["auto_map"] = add_model_info_to_auto_map(

@@ -648,6 +698,7 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
         """
         processor_dict = processor_dict.copy()
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+        chat_template = kwargs.pop("chat_template", None)

         # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
         # If we don't pop, some specific kwargs will raise a warning

@@ -659,6 +710,8 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
         unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
         processor = cls(*args, **processor_dict)
+        if chat_template is not None:
+            setattr(processor, "chat_template", chat_template)

         # Update processor with kwargs if needed
         for key in set(kwargs.keys()):
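Given the `json.dumps(..., indent=2, sort_keys=True)` call in `save_pretrained` above, the standalone `chat_template.json` that gets written and read back looks roughly as follows; the template string here is a toy placeholder rather than a real Llava template:

```python
import json

# Toy stand-in for a real Jinja chat template
chat_template = "{% for message in messages %}{{ message['role'] }}: {% endfor %}"

# Mirrors the serialization performed in save_pretrained
print(json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True))
# {
#   "chat_template": "{% for message in messages %}{{ message['role'] }}: {% endfor %}"
# }
```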

src/transformers/utils/__init__.py

+1

@@ -239,6 +239,7 @@
 FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
 IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
 PROCESSOR_NAME = "processor_config.json"
+CHAT_TEMPLATE_NAME = "chat_template.json"
 GENERATION_CONFIG_NAME = "generation_config.json"
 MODEL_CARD_NAME = "modelcard.json"
