
Commit 54072f3

mlinmg and Isotr0py authored
[MODEL ADDITION] Ovis2 Model Addition (vllm-project#15826)
Signed-off-by: Marco <121761685+mlinmg@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

1 parent: be633fb

File tree

17 files changed: +1349 −7 lines


docs/source/models/supported_models.md

Lines changed: 7 additions & 0 deletions
@@ -1014,6 +1014,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   *
   * ✅︎
   * ✅︎
+- * `Ovis2ForConditionalGeneration`<sup>^</sup>
+  * Ovis2
+  * T + I<sup>+</sup>
+  * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis2-2B`, etc.
+  *
+  *
+  * ✅︎
 - * `PaliGemmaForConditionalGeneration`
   * PaliGemma, PaliGemma 2
   * T + I<sup>E</sup>
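The `<sup>^</sup>` marker on the architecture name appears to indicate that the checkpoint must be mapped to vLLM's architecture name via an override. A minimal offline-inference sketch, mirroring only the EngineArgs values this commit itself uses (nothing here goes beyond what the diffs configure):

    from vllm import LLM

    # Ovis2 ships custom remote code and a separate tokenizer repo, so it
    # needs trust_remote_code plus an architectures override.
    llm = LLM(
        model="AIDC-AI/Ovis2-1B",
        tokenizer="Isotr0py/Ovis2-tokenizer",
        trust_remote_code=True,
        dtype="half",
        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
    )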

examples/offline_inference/vision_language.py

Lines changed: 31 additions & 0 deletions
@@ -725,6 +725,36 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+# Ovis2
+def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+
+    model_name = "AIDC-AI/Ovis2-1B"
+    tokenizer = "Isotr0py/Ovis2-tokenizer"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        tokenizer=tokenizer,
+        max_model_len=4096,
+        max_num_seqs=2,
+        trust_remote_code=True,
+        dtype="half",
+        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
+        limit_mm_per_prompt={"image": 1},
+    )
+
+    placeholder = "<image>\n"
+    prompts = [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+                f"<|im_start|>user\n{placeholder}"
+                f"{question}<|im_end|>\n"
+                "<|im_start|>assistant\n") for question in questions]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # PaliGemma
 def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

@@ -1041,6 +1071,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
+    "ovis2": run_ovis2,
     "paligemma": run_paligemma,
     "paligemma2": run_paligemma2,
     "phi3_v": run_phi3v,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 31 additions & 0 deletions
@@ -436,6 +436,36 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
+# Ovis2
+def load_ovis2(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "AIDC-AI/Ovis2-1B"
+    tokenizer = "Isotr0py/Ovis2-tokenizer"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        tokenizer=tokenizer,
+        max_model_len=8192,
+        max_num_seqs=2,
+        trust_remote_code=True,
+        dtype="half",
+        limit_mm_per_prompt={"image": len(image_urls)},
+        hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
+    )
+
+    placeholder = '\n'.join(
+        [f'Image {i+1}: <image>' for i in range(len(image_urls))]) + '\n'
+    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+              f"<|im_start|>user\n{placeholder}"
+              f"{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "mistral-community/pixtral-12b"
 
@@ -685,6 +715,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
     "mistral3": load_mistral3,
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
+    "ovis2": load_ovis2,
     "phi3_v": load_phi3v,
     "phi4_mm": load_phi4mm,
     "pixtral_hf": load_pixtral_hf,

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 12 additions & 0 deletions
@@ -467,6 +467,18 @@
         max_num_seqs=2,
         patch_hf_runner=model_utils.molmo_patch_hf_runner,
     ),
+    "ovis2": VLMTestInfo(
+        models=["AIDC-AI/Ovis2-1B"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<image>\n",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        dtype="half",
+        # use sdpa mode for hf runner since ovis2 didn't work with flash_attn
+        hf_model_kwargs={"llm_attn_implementation": "sdpa"},
+        patch_hf_runner=model_utils.ovis2_patch_hf_runner,
+    ),
     "phi3v": VLMTestInfo(
         models=["microsoft/Phi-3.5-vision-instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
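Assuming the usual pytest-parametrized layout of this file, the new settings should be runnable in isolation with something like:

    pytest tests/models/decoder_only/vision_language/test_models.py -k ovis2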

tests/models/decoder_only/vision_language/vlm_utils/core.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def run_test(
         "disable_mm_preprocessor_cache": True,
     }
     if model_info.tokenizer:
-        vllm_runner_kwargs_["tokenizer"] = model_info.tokenizer
+        vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
     if model_info.tokenizer_mode:
         vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
     if model_info.hf_overrides:
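This looks like a supporting fix rather than an Ovis2-specific change: the test runner's keyword for a custom tokenizer appears to be `tokenizer_name`, so the old `tokenizer` key would not have been accepted for models (like Ovis2) that specify a separate tokenizer repo.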

tests/models/decoder_only/vision_language/vlm_utils/model_utils.py

Lines changed: 30 additions & 0 deletions
@@ -676,3 +676,33 @@ def _generate(self, max_new_tokens=None, do_sample=None, **kwargs):
     hf_model.model.generate = types.MethodType(_generate, hf_model.model)
 
     return hf_model
+
+
+def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
+    """Patches and returns an instance of the HfRunner to use for Ovis2."""
+    hf_model.model.visual_tokenizer.to(hf_model.dtype)
+    hf_model.model.vte.to(hf_model.dtype)
+    hf_model.model.llm.to(hf_model.dtype)
+
+    hf_model.model.get_output_embeddings = lambda: \
+        hf_model.model.llm.get_output_embeddings()
+
+    def processor(*args, text="", images=None, **kwargs):
+        text_tokenizer = hf_model.model.get_text_tokenizer()
+        images = [images] if isinstance(images, Image) else images
+
+        text = text.split("<|im_start|>user\n")[1].split("<|im_end|>\n")[0]
+
+        prompt, input_ids, pixel_values = hf_model.model.preprocess_inputs(
+            text_or_conversations=text, images=images)
+        attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
+
+        inputs = {
+            "inputs": input_ids.unsqueeze(0),
+            "pixel_values": pixel_values.unsqueeze(0),
+            "attention_mask": attention_mask.unsqueeze(0),
+        }
+        return BatchFeature(data=inputs, tensor_type="pt")
+
+    hf_model.processor = processor
+    return hf_model
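By the look of it, this shim exists because Ovis2's remote code ships no HF `AutoProcessor`: the patched `processor` strips the chat template down to the raw user text, defers to the model's own `preprocess_inputs`, and repackages the tensors as a `BatchFeature` so the generic HfRunner comparison path can call it like any other processor.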

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -274,6 +274,7 @@ def _test_processing_correctness_mistral(
     "allenai/Molmo-7B-D-0924",
     "allenai/Molmo-7B-O-0924",
     "nvidia/NVLM-D-72B",
+    "AIDC-AI/Ovis2-1B",
     "google/paligemma-3b-mix-224",
     "google/paligemma2-3b-ft-docci-448",
     "microsoft/Phi-4-multimodal-instruct",

tests/models/registry.py

Lines changed: 4 additions & 0 deletions
@@ -348,6 +348,10 @@ def check_available_online(
                                          max_transformers_version="4.48",
                                          transformers_version_reason="Use of deprecated imports which have been removed.",  # noqa: E501
                                          extras={"phi3.5": "microsoft/Phi-3.5-vision-instruct"}),  # noqa: E501
+    "Ovis2ForConditionalGeneration": _HfExamplesInfo("AIDC-AI/Ovis2-1B",
+                                         tokenizer="Isotr0py/Ovis2-tokenizer",
+                                         trust_remote_code=True,
+                                         hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]}),  # noqa: E501
     "Phi4MMForCausalLM": _HfExamplesInfo("microsoft/Phi-4-multimodal-instruct",
                                          trust_remote_code=True),
     "PixtralForConditionalGeneration": _HfExamplesInfo("mistralai/Pixtral-12B-2409",  # noqa: E501

vllm/entrypoints/chat_utils.py

Lines changed: 3 additions & 2 deletions
@@ -496,9 +496,10 @@ def _placeholder_str(self, modality: ModalityStr,
         if model_type.startswith("llava"):
             return self._cached_token_str(self._tokenizer,
                                           hf_config.image_token_index)
+
         if model_type in ("aya_vision", "chameleon", "deepseek_vl_v2",
-                          "internvl_chat", "skywork_chat", "NVLM_D",
-                          "h2ovl_chat", "idefics3", "smolvlm"):
+                          "internvl_chat", "ovis2", "skywork_chat",
+                          "NVLM_D", "h2ovl_chat", "idefics3", "smolvlm"):
             return "<image>"
         if model_type in ("mllama", "llama4"):
             return "<|image|>"
