
Conversation

@Isotr0py (Member) commented Oct 22, 2025

Purpose

Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support.

Test Plan

```bash
python examples/offline_inference/vision_language_multi_image.py -m deepseek_ocr -n 4
pytest -s -v tests/models/multimodal/processing/test_common.py -k deepseek-ocr
pytest -s -v tests/models/multimodal/processing/test_tensor_schema.py -k deepseek-ocr
```

Test Result

All examples and tests should pass


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
mergify bot commented Oct 22, 2025

Documentation preview: https://vllm--27361.org.readthedocs.build/en/27361/

mergify bot added the documentation, deepseek, and multi-modality (#4194) labels Oct 22, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request fixes multi-image inference for deepseek-ocr and introduces `merge_by_field_config=True` with TensorSchema support for better input validation. The changes are generally well-structured, and the use of TensorSchema is a good improvement for code clarity and robustness. However, I've identified a critical issue in the logic for calculating the number of image patches that will likely cause a crash when processing a batch containing a mix of small (untiled) and large (tiled) images. The fix is straightforward, and I've provided a suggestion below.
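For context, here is a self-contained sketch of the idea behind tensor-schema validation: declare the expected shape of each multimodal input field and check it at construction time, rather than letting mismatches fail deep inside the forward pass. The class and checks below are illustrative only, not vLLM's actual `TensorSchema` API.

```python
from dataclasses import dataclass

import torch


@dataclass
class ImagePixelInputs:
    """Illustrative stand-in for a tensor schema: each field documents its
    expected shape, and consistency is verified when the object is built."""

    pixel_values: torch.Tensor         # expected (num_images, 3, h, w)
    images_spatial_crop: torch.Tensor  # expected (num_images, 2): [w_tiles, h_tiles]

    def __post_init__(self) -> None:
        if self.pixel_values.ndim != 4 or self.pixel_values.size(1) != 3:
            raise ValueError(
                f"pixel_values: expected (n, 3, h, w), got {tuple(self.pixel_values.shape)}"
            )
        n = self.pixel_values.size(0)
        if tuple(self.images_spatial_crop.shape) != (n, 2):
            raise ValueError(
                f"images_spatial_crop: expected ({n}, 2), got {tuple(self.images_spatial_crop.shape)}"
            )


# Shape mismatches surface here with a readable message, instead of as an
# opaque failure inside the vision encoder:
inputs = ImagePixelInputs(
    pixel_values=torch.zeros(2, 3, 640, 640),
    images_spatial_crop=torch.ones(2, 2, dtype=torch.long),
)
```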

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 518 to 520

```diff
 ) -> NestedTensors:
     images_in_this_batch = []
 
+    images_crop = images_crop.split(images_spatial_crop.prod(dim=-1).tolist())
     for jdx in range(images_spatial_crop.size(0)):
-        patches = images_crop[jdx][0].to(torch.bfloat16)
         image_ori = pixel_values[jdx]
         crop_shape = images_spatial_crop[jdx][0]
+        patches = images_crop[jdx]
```

P1: Handle zero local crops before splitting `images_crop`

The new logic assumes images_crop contains one patch for every width_tiles × height_tiles entry and unconditionally executes images_crop.split(images_spatial_crop.prod(dim=-1).tolist()). When the processor returns no local crops (e.g., each image is ≤640 px or cropping is disabled), images_crop is an empty tensor while images_spatial_crop still contains [1, 1] per image. The split sizes therefore sum to a positive number, causing RuntimeError: split_with_sizes expects split_sizes to sum exactly to dimension size and crashing even single-image inference with small images. The code needs to skip splitting or use zero lengths when no local crops exist.
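A minimal repro of this failure mode, plus one possible guard (the tensor shapes here are illustrative, not the model's actual ones):

```python
import torch

# Two small images: the processor produced no local crops, so images_crop is
# empty along dim 0, yet images_spatial_crop still records [1, 1] tiles per
# image, so the requested split sizes sum to 2 while the tensor has length 0.
images_crop = torch.empty(0, 3, 640, 640)
images_spatial_crop = torch.tensor([[1, 1], [1, 1]])

split_sizes = images_spatial_crop.prod(dim=-1).tolist()  # [1, 1]
# images_crop.split(split_sizes) would raise:
#   RuntimeError: split_with_sizes expects split_sizes to sum exactly to 0

# One possible guard: fall back to zero-length splits when no crops exist.
if images_crop.size(0) == 0:
    split_sizes = [0] * images_spatial_crop.size(0)

chunks = images_crop.split(split_sizes)
print([tuple(c.shape) for c in chunks])  # [(0, 3, 640, 640), (0, 3, 640, 640)]
```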


@Isotr0py added this to the v0.11.1 milestone Oct 22, 2025
@Isotr0py added the ready label Oct 22, 2025
@ywang96 merged commit 2566dca into vllm-project:main Oct 23, 2025 (55 checks passed)
@Isotr0py deleted the fix-deepseek-ocr branch Oct 23, 2025 03:38
usberkeley pushed a commit to usberkeley/vllm that referenced this pull request Oct 23, 2025
[Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (vllm-project#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
[Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (vllm-project#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
845473182 pushed a commit to raindaywhu/vllm that referenced this pull request Oct 24, 2025
…o step_forward

* 'step_forward' of https://github.com/raindaywhu/vllm: (148 commits)
  [Model] Add MoE support for NemotronH (vllm-project#25863)
  [Metrics] [KVConnector] Add connector prefix cache hit rate stats (vllm-project#26245)
  [CI] Reorganize entrypoints tests (vllm-project#27403)
  add SLA information into comparison graph for vLLM Benchmark Suite (vllm-project#25525)
  [CI/Build] Fix AMD CI: test_cpu_gpu.py (vllm-project#27388)
  [Bugfix] Fix args settings for guided decoding args (vllm-project#27375)
  [CI/Build] Fix Prithvi plugin test (vllm-project#27393)
  [Chore] Remove duplicate `has_` functions in vllm.utils (vllm-project#27372)
  [Model] Add num_cached_tokens for PoolingRequestOutput (vllm-project#27378)
  [V1][spec decode] return logprobs for spec decoding (vllm-project#26060)
  [CORE] Support Prefix Caching with Prompt Embeds (vllm-project#27219)
  [Bugfix][Core] running queue index leakage exception (vllm-project#26754)
  [Bugfix] Fix incorrect kv cache metrics in grafana.json (vllm-project#27133)
  [Bugfix] Fix SLA tuner initialization (vllm-project#27355)
  [Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (vllm-project#27361)
  [MLA] Bump FlashMLA (vllm-project#27354)
  [Chore] Separate out system utilities from vllm.utils (vllm-project#27201)
  [BugFix] bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490 (vllm-project#27128)
  [Feature] publisher default set zmq in kv_event config (vllm-project#26915)
  [Prefix Cache] Use LoRA name for consistent KV-cache block hashing (vllm-project#27211)
  ...
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
[Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (vllm-project#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
[Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (vllm-project#27361)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>