
Commit 251a240

Authored by jamt9000, zucchini-nlp, and amyeroberts
Add llama3-llava-next-8b to llava_next conversion script (huggingface#31395)
* Add llama3-llava-next-8b to llava_next conversion script

  Adds support for the lmms-lab/llama3-llava-next-8b model to the convert_llava_next_weights_to_hf.py script, along with an example prompt generated from the llava_llama_3 conv_template in the LLaVA-NeXT repo.

* Exclude <|begin_of_text|> from prompt example

  This token gets added automatically, so it should not be included in the prompt example.

* Add llava-next-72b and llava-next-110b

  Adds the Qwen-based LLaVA-Next models to the conversion script, along with changes to load the models on multiple GPUs for inference.

* Add llama3 and qwen prompt formats to docs

* Chat prompt and padding side left for llama3 batched

* update

* Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py

  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove code

* better naming

---------

Co-authored-by: raushan <raushan@huggingface.co>
Co-authored-by: Raushan Turganbay <raushan.turganbay@alumni.nu.edu.kz>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
1 parent 96a074f commit 251a240
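
For orientation, here is a minimal sketch of driving the updated conversion for the newly supported Llama-3-based checkpoint. The function name, arguments, and model id come from the diff below; the dump folder path is a placeholder, and importing the conversion module directly assumes a source install of transformers.

```python
# Minimal sketch (not part of the commit). Equivalent CLI form, with a
# placeholder output path:
#   python src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py \
#       --model_id lmms-lab/llama3-llava-next-8b \
#       --pytorch_dump_folder_path ./llama3-llava-next-8b-hf
from transformers.models.llava_next.convert_llava_next_weights_to_hf import convert_llava_to_hf

# --pytorch_dump_folder_path is now required: the script saves the converted
# weights there and then reloads them with device_map="auto" for its checks.
convert_llava_to_hf(
    model_id="lmms-lab/llama3-llava-next-8b",
    pytorch_dump_folder_path="./llama3-llava-next-8b-hf",
    push_to_hub=False,
)
```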

File tree

2 files changed, +101 -35 lines

docs/source/en/model_doc/llava_next.md (+11)

@@ -100,6 +100,17 @@ print(text_prompt)
 "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
 ```
 
+[llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llava-next-8b-hf) requires the following format:
+
+```bash
+"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+```
+
+[llava-next-72b-hf](https://huggingface.co/llava-hf/llava-next-72b-hf) and [llava-next-110b-hf](https://huggingface.co/llava-hf/llava-next-110b-hf) require the following format:
+
+```bash
+"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
+```
 
 ## Usage example
 
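To illustrate the prompt formats documented above, here is a minimal sketch of preparing inputs with the Llama-3 format. The checkpoint name is an assumption derived from the script's llava-hf/<checkpoint>-hf push naming, the test image is a stand-in, and the prompt string is copied verbatim from the added doc lines.

```python
# Minimal sketch (not part of the commit); repo id assumed from the script's
# push_to_hub naming (llava-hf/<checkpoint>-hf).
import requests
from PIL import Image
from transformers import LlavaNextProcessor

processor = LlavaNextProcessor.from_pretrained("llava-hf/llama3-llava-next-8b-hf")

# Stand-in test image (two cats from COCO).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Prompt copied verbatim from the doc addition above; <|begin_of_text|> is
# intentionally omitted because the tokenizer adds it automatically.
prompt = (
    "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, and assist the user with "
    "a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user"
    "<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant"
    "<|end_header_id|>\n\n"
)

inputs = processor(images=image, text=prompt, return_tensors="pt")
print(inputs.input_ids.shape)
```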
src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py (+90 -35)
@@ -24,6 +24,7 @@
 """
 
 import argparse
+import gc
 import glob
 import json
 from pathlib import Path
@@ -111,6 +112,16 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
     elif model_id == "liuhaotian/llava-v1.6-34b":
         text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
         image_token_index = 64000
+    elif model_id == "lmms-lab/llama3-llava-next-8b":
+        text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+        image_token_index = 128256
+    elif model_id == "lmms-lab/llava-next-72b":
+        text_model_id = "Qwen/Qwen1.5-72B-Chat"
+        image_token_index = 151646
+    elif model_id == "lmms-lab/llava-next-110b":
+        text_model_id = "Qwen/Qwen1.5-110B-Chat"
+        image_token_index = 151646
+
     vision_model_id = data["mm_vision_tower"]
 
     torch.set_default_dtype(torch.float16)
@@ -120,7 +131,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
     tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast)
     tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
 
-    if model_id == "liuhaotian/llava-v1.6-mistral-7b":
+    if model_id in ("liuhaotian/llava-v1.6-mistral-7b", "lmms-lab/llama3-llava-next-8b"):
         # Mistral-7B doesn't have a padding token set yet
         tokenizer.add_special_tokens({"pad_token": "<pad>"})
 
@@ -151,28 +162,45 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
 
     # We add an image token so we resize the model
     # Pad to 64 for performance reasons
-    pad_shape = 64
-    vocab_size = config.text_config.vocab_size
-    if model_id == "liuhaotian/llava-v1.6-34b":
-        # this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and <image>
-        num_tokens = vocab_size + 3
-    else:
-        # this one has 2 additional tokens, namely <image> and <pad>
-        num_tokens = vocab_size + 2
-    model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
-    model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
-        tuple(
-            (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
-        ),
-        dim=0,
-    )
-    model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
-        tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
-        dim=0,
-    )
+    # Qwen-based models have extra unused space in the vocab size already, so no need to resize
+    if model_id not in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]:
+        pad_shape = 64
+        vocab_size = config.text_config.vocab_size
+        if model_id == "liuhaotian/llava-v1.6-34b":
+            # this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and <image>
+            num_tokens = vocab_size + 3
+        else:
+            # this one has 2 additional tokens, namely <image> and <pad>
+            num_tokens = vocab_size + 2
+        model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
+        model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
+            tuple(
+                (
+                    dist.sample()
+                    for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])
+                )
+            ),
+            dim=0,
+        )
+        model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
+            tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
+            dim=0,
+        )
+
+    print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
+    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
+    model.save_pretrained(pytorch_dump_folder_path)
+    processor.save_pretrained(pytorch_dump_folder_path)
+
+    # Make space so we can load the model properly now.
+    del state_dict
+    gc.collect()
 
-    device = "cuda:2"
-    model.to(device)
+    # Load everything back for inference tests in float32 because prev script was written as that
+    # Though it's mostly loaded in fp16 as original weights are in fp16
+    model = LlavaNextForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, device_map="auto")
+    processor = LlavaNextProcessor.from_pretrained(pytorch_dump_folder_path)
+    device = model.device
 
     # prepare inputs
     image = load_image()
@@ -182,6 +210,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
         prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
     elif model_id == "liuhaotian/llava-v1.6-34b":
         prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
+    elif model_id == "lmms-lab/llama3-llava-next-8b":
+        prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+    elif model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]:
+        prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
+
     inputs = processor(images=image, text=prompt, return_tensors="pt")
 
     # verify inputs
@@ -194,8 +227,6 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
         original_input_ids = torch.load(filepath, map_location="cpu")
         # replace -200 by image_token_index (since we use token ID = 32000 for the image token)
         original_input_ids[original_input_ids == -200] = image_token_index
-        print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200]))
-
         assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()
 
     elif model_id == "liuhaotian/llava-v1.6-34b":
@@ -243,6 +274,26 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
             dtype=torch.float32,
             device=device,
         )
+    elif model_id == "lmms-lab/llama3-llava-next-8b":
+        expected_slice = torch.tensor(
+            [[-3.9648, 1.1396, 3.3145], [-5.3594, -1.5654, -1.9619], [-12.3750, -10.6797, -9.3125]],
+            dtype=torch.float32,
+            device=device,
+        )
+    elif model_id == "lmms-lab/llava-next-72b":
+        # Not yet checked against reference
+        expected_slice = torch.tensor(
+            [[3.7148, 3.9277, 3.4395], [-0.4341, 1.1387, 6.5117], [3.2324, 3.4688, 4.1133]],
+            dtype=torch.float32,
+            device=device,
+        )
+    elif model_id == "lmms-lab/llava-next-110b":
+        # Not yet checked against reference
+        expected_slice = torch.tensor(
+            [[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]],
+            dtype=torch.float32,
+            device=device,
+        )
     else:
         raise ValueError(f"Model {model_id} not supported")
 
@@ -268,6 +319,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
         expected_text = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a radar chart, also known as a spider chart or star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several variables represented:\n\n- MM-Vet\n- LLa-Va-Bench\n- SEED-Bench\n- MM"
     elif model_id == "liuhaotian/llava-v1.6-34b":
         expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-"
+    elif model_id == "lmms-lab/llama3-llava-next-8b":
+        expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LL'
+    elif model_id == "lmms-lab/llava-next-72b":
+        expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image displays a radar chart, also known as a spider chart or a star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the value of each variable is represented by the distance from the center of the chart to the point where the axis intersects with the line representing that variable's value.\n\nIn this particular chart, there are several axes"
+    elif model_id == "lmms-lab/llava-next-110b":
+        expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a"
     else:
         raise ValueError(f"Model {model_id} not supported")
 
@@ -281,7 +338,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
 
     inputs = processor(
         images=[image, cats_image],
-        text=[prompt, "[INST] <image>\nHow many cats are there? [/INST]"],
+        text=[prompt, prompt],
         padding=True,
         return_tensors="pt",
     ).to(device)
@@ -305,16 +362,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
     outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
     print(outputs)
 
-    if pytorch_dump_folder_path is not None:
-        print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
-        Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
-        model.save_pretrained(pytorch_dump_folder_path)
-        processor.save_pretrained(pytorch_dump_folder_path)
-
     if push_to_hub:
-        repo_id = model_id.split("/")[-1]
-        model.push_to_hub(f"llava-hf/{repo_id}-hf")
-        processor.push_to_hub(f"llava-hf/{repo_id}-hf")
+        checkpoint_name = model_id.split("/")[-1]
+        print(f"Pushing to repo llava-hf/{checkpoint_name}-hf")
+        model.push_to_hub(f"llava-hf/{checkpoint_name}-hf")
+        processor.push_to_hub(f"llava-hf/{checkpoint_name}-hf")
 
 
 if __name__ == "__main__":
@@ -328,11 +380,14 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
             "liuhaotian/llava-v1.6-vicuna-7b",
             "liuhaotian/llava-v1.6-vicuna-13b",
             "liuhaotian/llava-v1.6-34b",
+            "lmms-lab/llama3-llava-next-8b",
+            "lmms-lab/llava-next-72b",
+            "lmms-lab/llava-next-110b",
         ],
         required=False,
     )
     parser.add_argument(
-        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+        "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory."
     )
     parser.add_argument(
         "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
