Commit 7f3d4ae

[Model] Port deepseek-vl2 processor, remove dependency (vllm-project#12169)
Signed-off-by: Isotr0py <2037008807@qq.com>
Parent: 20308d5

8 files changed (+385, −49 lines)

.buildkite/test-pipeline.yaml
Lines changed: 0 additions & 1 deletion

@@ -52,7 +52,6 @@ steps:
   - tests/worker
   - tests/standalone_tests/lazy_torch_compile.py
   commands:
-  - pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git # Used by multimoda processing test
   - python3 standalone_tests/lazy_torch_compile.py
   - pytest -v -s mq_llm_engine # MQLLMEngine
   - pytest -v -s async_engine # AsyncLLMEngine

docs/source/models/supported_models.md
Lines changed: 2 additions & 8 deletions

@@ -767,16 +767,10 @@ See [this page](#generative-models) for more information on how to use generativ
 <sup>E</sup> Pre-computed embeddings can be inputted for this modality.
 <sup>+</sup> Multiple items can be inputted per text prompt for this modality.
 
-````{note}
-To use `DeepSeek-VL2` series models, you need to install a fork version `deepseek_vl2` package:
-
-```shell
-pip install git+https://github.com/Isotr0py/DeepSeek-VL2.git
+```{note}
+To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
 ```
 
-Besides, to run `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
-````
-
 ```{note}
 To use `TIGER-Lab/Mantis-8B-siglip-llama3`, you have to pass `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM.
 ```
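The documented CLI flag has a Python-API counterpart. A minimal offline-inference sketch (the `hf_overrides` constructor argument is an assumption based on the docs in this commit; the model ID is taken from the tests below, and neither line appears in the diff itself):

```python
from vllm import LLM

# Minimal sketch, assuming the LLM constructor accepts `hf_overrides`:
# override the reported architecture so vLLM selects the in-tree
# DeepseekVLV2ForCausalLM implementation, as the note above describes.
llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
```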

examples/offline_inference/vision_language_multi_image.py
Lines changed: 1 addition & 1 deletion

@@ -393,7 +393,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
 
 model_example_map = {
     "aria": load_aria,
-    "deepseek_vl2": load_deepseek_vl2,
+    "deepseek_vl_v2": load_deepseek_vl2,
     "h2ovl_chat": load_h2onvl,
     "idefics3": load_idefics3,
     "internvl_chat": load_internvl,

tests/models/decoder_only/vision_language/test_models.py
Lines changed: 1 addition & 1 deletion

@@ -190,7 +190,7 @@
         dtype="bfloat16",
     ),
     "deepseek_vl_v2": VLMTestInfo(
-        models=["deepseek-ai/deepseek-vl2-tiny"],
+        models=["Isotr0py/deepseek-vl2-tiny"],  # model repo using dynamic module
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
         prompt_formatter=lambda img_prompt: f"<|User|>: {img_prompt}\n\n<|Assistant|>: ",  # noqa: E501
         max_model_len=4096,

tests/models/multimodal/processing/test_common.py
Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,8 @@ def _test_processing_correctness(
 ):
     if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
         hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
+    elif model_id == "deepseek-ai/deepseek-vl2-tiny":
+        hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
     else:
         hf_overrides = {}

@@ -139,6 +141,7 @@ def _test_processing_correctness(
     ("rhymes-ai/Aria", {"image": True}),
     ("Salesforce/blip2-opt-2.7b", {"image": False}),
     ("facebook/chameleon-7b", {"image": False}),
+    ("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
     ("adept/fuyu-8b", {"image": False}),
     ("llava-hf/llava-1.5-7b-hf", {"image": True}),
     ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),

vllm/model_executor/models/deepseek_vl2.py
Lines changed: 13 additions & 38 deletions

@@ -1,15 +1,15 @@
 # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
 """Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
 import math
-from functools import cached_property, partial
+from functools import cached_property
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from transformers import AutoProcessor, BatchFeature, ProcessorMixin
+from transformers import BatchFeature
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig

@@ -31,6 +31,8 @@
 from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
                                                           MlpProjectorConfig,
                                                           VisionEncoderConfig)
+from vllm.transformers_utils.processors.deepseek_vl2 import (
+    DeepseekVLV2Processor)
 from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP

@@ -129,25 +131,8 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config(DeepseekVLV2Config)
 
-    def get_hf_processor(self) -> ProcessorMixin:
-        # TODO(Isotr0py): we should get rid of dependency on deepseek_vl2
-        # in the future, because it's flasky and lack of maintenance.
-        try:
-            from deepseek_vl2.models.processing_deepseek_vl_v2 import (
-                DeepseekVLV2Processor, select_best_resolution)
-            AutoProcessor.register("DeepseekVLV2Processor",
-                                   DeepseekVLV2Processor)
-        except ModuleNotFoundError as exc:
-            raise ModuleNotFoundError(
-                "You need to `pip install "
-                "git+https://github.com/deepseek-ai/DeepSeek-VL2.git` "
-                "to use this model") from exc
-
-        processor = self.ctx.get_hf_processor(DeepseekVLV2Processor)
-        processor.select_best_resolution = partial(
-            select_best_resolution,
-            candidate_resolutions=processor.candidate_resolutions)
-        return processor
+    def get_hf_processor(self) -> DeepseekVLV2Processor:
+        return self.ctx.get_hf_processor(DeepseekVLV2Processor)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}

@@ -224,31 +209,21 @@ def _call_hf_processor(
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         if mm_data:
-            outputs = self.info.ctx.call_hf_processor(
+            processed_outputs = self.info.ctx.call_hf_processor(
                 self.info.get_hf_processor(**mm_kwargs),
                 dict(prompt=prompt, **mm_data),
                 mm_kwargs,
             )
-
-            # Deepseek-vl2 processor don't return BatchFeature,
-            # we need to manually create it
-            processed_outputs = dict(input_ids=outputs["input_ids"])
-            processed_outputs = BatchFeature(data=dict(processed_outputs),
-                                             tensor_type="pt")
-
-            # Remove batch dimension from processor outputs,
-            # because we will try batch to create NestedTensors
             target_dtype = self.info.ctx.model_config.dtype
-            pixel_values = outputs["images"].to(target_dtype).squeeze(0)
-            images_spatial_crop = outputs["images_spatial_crop"].squeeze(0)
+            pixel_values = processed_outputs.pop("pixel_values").to(
+                target_dtype)
+            # split pixel values into patches corresponding to each image
+            images_spatial_crop = processed_outputs["images_spatial_crop"]
             patches_per_image = [
                 x.prod().item() + 1 for x in images_spatial_crop
             ]
-
-            # Rename `images` -> `pixel_values` to avoid confusion
-            processed_outputs["pixel_values"] = list(
-                pixel_values.split(patches_per_image))
-            processed_outputs["images_spatial_crop"] = images_spatial_crop
+            pixel_values = pixel_values.split(patches_per_image)
+            processed_outputs["pixel_values"] = pixel_values
         else:
             tokenizer = self.info.get_tokenizer()
             processed_outputs = tokenizer(prompt,
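The per-image split above is the crux of the new path: `images_spatial_crop` holds each image's tiling grid, and every image contributes the product of its grid dimensions in local patches plus one global view, which is how the flat `pixel_values` tensor is split back per image. A toy illustration of that arithmetic with made-up grids and shapes (the real tensors come from `DeepseekVLV2Processor`):

```python
import torch

# Hypothetical tiling grids for two images: a 2x2 grid and a 1x3 grid.
images_spatial_crop = torch.tensor([[2, 2], [1, 3]])

# prod(grid) local patches + 1 global view per image -> [5, 4]
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]

# The processor returns all patches stacked along dim 0; split per image.
pixel_values = torch.randn(sum(patches_per_image), 3, 384, 384)
per_image = pixel_values.split(patches_per_image)
print([p.shape[0] for p in per_image])  # [5, 4]
```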
vllm/transformers_utils/processors/__init__.py
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from vllm.transformers_utils.processors.deepseek_vl2 import (
+    DeepseekVLV2Processor)
+
+__all__ = ["DeepseekVLV2Processor"]
