50 changes: 48 additions & 2 deletions examples/offline_inference/vision_language_multi_image.py
@@ -44,6 +44,7 @@ class ModelRequestData(NamedTuple):
stop_token_ids: list[int] | None = None
chat_template: str | None = None
lora_requests: list[LoRARequest] | None = None
sampling_params: SamplingParams | None = None


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
@@ -201,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
)


def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

model_name = "deepseek-ai/DeepSeek-OCR"

engine_args = EngineArgs(
model=model_name,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
logits_processors=[NGramPerReqLogitsProcessor],
)

placeholder = "<image>\n" * len(image_urls)
prompt = placeholder + question

# The following sampling params config is taken from
# the official DeepSeek-OCR inference example.
# (IMPORTANT) Use the custom logits processor and avoid skipping
# special tokens for this model for optimal OCR performance.
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=8192,
# ngram logits processor args
extra_args=dict(
ngram_size=30,
window_size=90,
# whitelist: <td>, </td>
whitelist_token_ids={128821, 128822},
),
skip_special_tokens=False,
)

return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
sampling_params=sampling_params,
)
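# Editorial note: the `extra_args` above are consumed by the per-request n-gram
# logits processor registered via `logits_processors=[NGramPerReqLogitsProcessor]`.
# A minimal standalone sketch of the idea (illustrative only, not vLLM's
# implementation): ban any token that would complete an n-gram already seen in a
# trailing window of the output, unless the token is whitelisted.
def _banned_ngram_tokens(
    output_ids: list[int],
    ngram_size: int,
    window_size: int,
    whitelist_token_ids: set[int],
) -> set[int]:
    # Only look at the most recent `window_size` tokens; assumes ngram_size >= 2.
    window = output_ids[-window_size:]
    if len(window) < ngram_size:
        return set()
    # The prefix that the next token would extend into a full n-gram.
    prefix = tuple(window[-(ngram_size - 1):])
    banned: set[int] = set()
    for i in range(len(window) - ngram_size + 1):
        ngram = tuple(window[i : i + ngram_size])
        if ngram[:-1] == prefix:
            banned.add(ngram[-1])
    # Structural tokens such as <td>/</td> may legitimately repeat.
    return banned - whitelist_token_ids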


def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it"

@@ -1253,6 +1294,7 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
"bee": load_bee,
"command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2,
"deepseek_ocr": load_deepseek_ocr,
"gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision,
@@ -1325,8 +1367,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: int | None)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)

sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
sampling_params = (
SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
if req_data.sampling_params is None
else req_data.sampling_params
)
outputs = llm.chat(
[
1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -332,6 +332,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b",
"CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny",
"deepseek-ai/DeepSeek-OCR",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
113 changes: 58 additions & 55 deletions vllm/model_executor/models/deepseek_ocr.py
@@ -4,6 +4,7 @@

import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal

import torch
import torch.nn as nn
@@ -53,6 +54,7 @@
count_tiles,
)
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.v1.sample.logits_processor import (
AdapterLogitsProcessor,
RequestLogitsProcessor,
@@ -65,6 +67,28 @@
_IMAGE_TOKEN = "<image>"


class DeepseekOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- bnp: Batch size * number of images * number of patches
- base_size: Base size of the processor
- image_size: Image size of the processor
"""

type: Literal["pixel_values"]
data: Annotated[
torch.Tensor,
TensorShape("bn", 3, "base_size", "base_size", dynamic_dims={"bnp"}),
]
images_crop: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "image_size", "image_size", dynamic_dims={"bnp"}),
]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
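# Editorial shape sketch (assumed values, for illustration only): with two
# images, the first tiled into a 2x2 grid and the second left untiled, tensors
# matching the schema above would look like:
#   data                = torch.zeros((2, 3, 1024, 1024))  # assuming base_size == 1024
#   images_crop         = torch.zeros((4, 3, 640, 640))    # 4 + 0 crops, assuming image_size == 640
#   images_spatial_crop = torch.tensor([[2, 2], [1, 1]])   # tile grid per image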


class NoRepeatNGramLogitsProcessor:
def __init__(
self,
@@ -260,10 +284,15 @@ def _get_mm_fields_config(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
images_spatial_crop = hf_inputs.get("images_spatial_crop", torch.empty((0, 2)))
is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.batched("image"),
images_crop=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image
),
)
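# Editorial note (assumed semantics, for illustration): `flat_from_sizes` marks
# `images_crop` as one tensor of crops flattened across all images, sliced back
# per image by `patches_per_image`; e.g. sizes [4, 0, 6] assign crops 0:4 to
# image 0, none to image 1, and 4:10 to image 2, whereas `batched("image")`
# would expect the leading dim to index images directly.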

def _get_prompt_updates(
@@ -302,42 +331,15 @@ def get_replacement_deepseek_vl2(item_idx: int):
)
]

# TODO(Isotr0py): Check if we still need this workaround for
# deepseek-ocr processor.
# def _cached_apply_hf_processor(
# self,
# prompt: str | list[int],
# mm_data_items: MultiModalDataItems,
# hf_processor_mm_kwargs: Mapping[str, object],
# tokenization_kwargs: Mapping[str, object],
# mm_uuids: MultiModalUUIDDict | None = None,
# ) -> tuple[list[int], MultiModalKwargs, bool]:
# # The processor logic is different for len(images) <= 2 vs > 2
# # Since the processing cache assumes that the processor output is
# # invariant of how many images are passed per prompt, we only
# # perform caching for the most common case
# if mm_data_items.get_count("image", strict=False) > 2:
# # This code path corresponds to the cache being disabled
# return self._apply_hf_processor_main(
# prompt=prompt,
# mm_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# enable_hf_prompt_update=True,
# )

# return super()._cached_apply_hf_processor(
# prompt=prompt,
# mm_data_items=mm_data_items,
# hf_processor_mm_kwargs=hf_processor_mm_kwargs,
# )


@MULTIMODAL_REGISTRY.register_processor(
DeepseekOCRMultiModalProcessor,
info=DeepseekOCRProcessingInfo,
dummy_inputs=DeepseekOCRDummyInputsBuilder,
)
class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# map prefix for language backbone
@@ -389,6 +391,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.vision_model = DeepCLIPVisionTransformer(
config=clip_vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_model"),
)

self.projector = MlpProjector(self.projector_config)
@@ -426,7 +429,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.language_model.make_empty_intermediate_tensors
)

def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object
) -> DeepseekOCRImagePixelInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
images_spatial_crop = kwargs.pop("images_spatial_crop", None)
images_crop = kwargs.pop("images_crop", None)
@@ -435,23 +440,16 @@ def _parse_and_validate_image_input(self, **kwargs: object):
return None

if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
)

if not isinstance(images_spatial_crop, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image sizes. "
f"Got type: {type(images_spatial_crop)}"
)

if not isinstance(images_crop, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image crop. Got type: {type(images_crop)}"
)

return [pixel_values, images_crop, images_spatial_crop]
base_size = self.vision_config.image_size
return DeepseekOCRImagePixelInputs(
type="pixel_values",
data=pixel_values,
images_crop=images_crop,
images_spatial_crop=images_spatial_crop,
resolve_bindings={
"base_size": base_size,
},
)

raise AssertionError("This line should be unreachable.")

@@ -518,10 +516,13 @@ def _pixel_values_to_embedding(
) -> NestedTensors:
images_in_this_batch = []

is_tiled = (images_spatial_crop[:, 0] > 1) | (images_spatial_crop[:, 1] > 1)
patches_per_image = torch.where(is_tiled, images_spatial_crop.prod(dim=-1), 0)
images_crop = images_crop.split(patches_per_image.tolist())
for jdx in range(images_spatial_crop.size(0)):
patches = images_crop[jdx][0].to(torch.bfloat16)
image_ori = pixel_values[jdx]
crop_shape = images_spatial_crop[jdx][0]
patches = images_crop[jdx]
image_ori = pixel_values[[jdx]]
crop_shape = images_spatial_crop[jdx]

global_features = self._encode_global_features(image_ori)
local_features = self._encode_local_features(patches, crop_shape)
@@ -540,10 +541,12 @@

return images_in_this_batch
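# Worked example of the tiling bookkeeping above (illustrative values only):
# for tile grids [[2, 2], [1, 1], [3, 2]], only the tiled images contribute
# local crops, so the flattened `images_crop` splits into chunks of 4, 0 and 6:
#   spatial = torch.tensor([[2, 2], [1, 1], [3, 2]])
#   is_tiled = (spatial[:, 0] > 1) | (spatial[:, 1] > 1)        # [True, False, True]
#   per_image = torch.where(is_tiled, spatial.prod(dim=-1), 0)  # [4, 0, 6]
#   chunks = crops.split(per_image.tolist())                    # len(chunks) == 3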

def _process_image_input(self, image_input) -> torch.Tensor:
pixel_values = image_input[0].to(torch.bfloat16)
images_crop = image_input[1]
images_spatial_crop = image_input[2].to(dtype=torch.long)
def _process_image_input(
self, image_input: DeepseekOCRImagePixelInputs
) -> torch.Tensor:
pixel_values = image_input.data
images_crop = image_input.images_crop
images_spatial_crop = image_input.images_spatial_crop.to(dtype=torch.long)

vision_features = self._pixel_values_to_embedding(
pixel_values=pixel_values,
14 changes: 5 additions & 9 deletions vllm/transformers_utils/processors/deepseek_ocr.py
@@ -411,20 +411,16 @@ def tokenize_with_images(
images_seq_mask = images_seq_mask[:-1]

if len(images_list) == 0:
pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
pixel_values = torch.zeros((0, 3, self.base_size, self.base_size))
images_spatial_crop = torch.zeros((0, 2), dtype=torch.long)
images_crop = torch.zeros((0, 3, self.image_size, self.image_size))
else:
pixel_values = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
if images_crop_list:
images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
images_crop = torch.stack(images_crop_list, dim=0)
else:
images_crop = torch.zeros(
(1, 3, self.image_size, self.image_size)
).unsqueeze(0)
images_crop = torch.zeros((0, 3, self.image_size, self.image_size))

input_ids = input_ids.unsqueeze(0)

Expand Down