[Bugfix] Fix deepseek-vl2 inference with more than 2 images #13818

Merged: 3 commits merged into vllm-project:main from Isotr0py:fix-dsvl2-repl on Feb 25, 2025

Conversation

Isotr0py
Collaborator

@Isotr0py Isotr0py commented Feb 25, 2025

FIX #13396

  • DeepSeek-VL2 crops each image to (image_size, image_size) instead of using select_best_resolution when num_images > 2 (a rough sketch of this branch follows below)
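
For context, a rough, self-contained sketch of the branching behavior described above; select_best_resolution here is a simplified stand-in, not the exact HF DeepSeek-VL2 processor helper:

```python
from PIL import Image

def select_best_resolution(size, candidate_resolutions):
    # Simplified stand-in: choose the candidate (w, h) that keeps the most
    # effective pixels with the least wasted area when scaling to fit.
    orig_w, orig_h = size
    def fit(cand):
        w, h = cand
        scale = min(w / orig_w, h / orig_h)
        effective = int(orig_w * scale) * int(orig_h * scale)
        return (-effective, w * h - effective)
    return min(candidate_resolutions, key=fit)

def resize_for_deepseek_vl2(image: Image.Image, image_size: int,
                            candidate_resolutions, num_images: int) -> Image.Image:
    # With at most 2 images the processor tiles at the best dynamic
    # resolution; with more than 2 it falls back to a fixed square crop,
    # which changes the token count per image -- the mismatch this PR fixes.
    if num_images <= 2:
        w, h = select_best_resolution(image.size, candidate_resolutions)
        return image.resize((w, h))
    return image.resize((image_size, image_size))
```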

Signed-off-by: Isotr0py <2037008807@qq.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@DarkLight1337
Member

Since the processing differs depending on the number of images, we have to disable the processing cache when the maximum number of images per prompt is greater than one (similar to the case for H2O-VL).
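
A minimal sketch of that cache-gating pattern, assuming a hypothetical info.get_mm_limits() accessor (the real constructor in vllm/model_executor/models/deepseek_vl2.py differs; the merged check gates on more than two images, matching the PR description and the traceback reported further below):

```python
# Sketch only: the surrounding names (info, get_mm_limits) are
# illustrative, not the exact vLLM API.
class DeepseekVL2MultiModalProcessor:
    def __init__(self, info, dummy_inputs_builder, *, cache=None):
        self.info = info
        self.dummy_inputs_builder = dummy_inputs_builder
        self.cache = cache

        mm_limit = info.get_mm_limits()  # e.g. {"image": 4}
        # HF processing switches to a fixed square crop once a prompt can
        # carry more than two images, so per-item cached outputs are not
        # reusable across that boundary: disable the cache entirely.
        if self.cache is not None and mm_limit["image"] > 2:
            self.cache = None
```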

Signed-off-by: Isotr0py <2037008807@qq.com>
@DarkLight1337
Member

Can you also update H2O-VL to only disable the cache when multi-image is enabled?

Signed-off-by: Isotr0py <2037008807@qq.com>
Member

@DarkLight1337 DarkLight1337 left a comment


Thanks, LGTM now

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) February 25, 2025 11:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 25, 2025
@simon-mo simon-mo merged commit 6ff5186 into vllm-project:main Feb 25, 2025
32 of 36 checks passed
@Isotr0py Isotr0py deleted the fix-dsvl2-repl branch February 25, 2025 14:06
@SLKun
Copy link

SLKun commented Mar 7, 2025

I faced some problems when starting deepseek-ai/deepseek-vl2. Here is the command line:

vllm serve --hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}' deepseek-ai/deepseek-vl2

Here is the traceback; it seems this issue is related to this PR.

INFO 03-07 14:56:09 [model_runner.py:1117] Loading model weights took 56.2208 GB and 15.182009 seconds
ERROR 03-07 14:56:09 [engine.py:400] 'image'
ERROR 03-07 14:56:09 [engine.py:400] Traceback (most recent call last):
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 391, in run_mp_engine
ERROR 03-07 14:56:09 [engine.py:400]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 03-07 14:56:09 [engine.py:400]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 124, in from_engine_args
ERROR 03-07 14:56:09 [engine.py:400]     return cls(ipc_path=ipc_path,
ERROR 03-07 14:56:09 [engine.py:400]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 76, in __init__
ERROR 03-07 14:56:09 [engine.py:400]     self.engine = LLMEngine(*args, **kwargs)
ERROR 03-07 14:56:09 [engine.py:400]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 276, in __init__
ERROR 03-07 14:56:09 [engine.py:400]     self._initialize_kv_caches()
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 421, in _initialize_kv_caches
ERROR 03-07 14:56:09 [engine.py:400]     self.model_executor.determine_num_available_blocks())
ERROR 03-07 14:56:09 [engine.py:400]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 102, in determine_num_available_blocks
ERROR 03-07 14:56:09 [engine.py:400]     results = self.collective_rpc("determine_num_available_blocks")
ERROR 03-07 14:56:09 [engine.py:400]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 03-07 14:56:09 [engine.py:400]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 03-07 14:56:09 [engine.py:400]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/utils.py", line 2232, in run_method
ERROR 03-07 14:56:09 [engine.py:400]     return func(*args, **kwargs)
ERROR 03-07 14:56:09 [engine.py:400]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-07 14:56:09 [engine.py:400]     return func(*args, **kwargs)
ERROR 03-07 14:56:09 [engine.py:400]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
ERROR 03-07 14:56:09 [engine.py:400]     self.model_runner.profile_run()
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-07 14:56:09 [engine.py:400]     return func(*args, **kwargs)
ERROR 03-07 14:56:09 [engine.py:400]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1238, in profile_run
ERROR 03-07 14:56:09 [engine.py:400]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1282, in _dummy_run
ERROR 03-07 14:56:09 [engine.py:400]     max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
ERROR 03-07 14:56:09 [engine.py:400]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 324, in get_max_multimodal_tokens
ERROR 03-07 14:56:09 [engine.py:400]     return sum(self.get_max_tokens_by_modality(model_config).values())
ERROR 03-07 14:56:09 [engine.py:400]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 306, in get_max_tokens_by_modality
ERROR 03-07 14:56:09 [engine.py:400]     mm_limits = self.get_mm_limits_per_prompt(model_config)
ERROR 03-07 14:56:09 [engine.py:400]                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 375, in get_mm_limits_per_prompt
ERROR 03-07 14:56:09 [engine.py:400]     processor = self.create_processor(model_config, tokenizer)
ERROR 03-07 14:56:09 [engine.py:400]                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 450, in create_processor
ERROR 03-07 14:56:09 [engine.py:400]     return factories.build_processor(ctx, cache=cache)
ERROR 03-07 14:56:09 [engine.py:400]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 83, in build_processor
ERROR 03-07 14:56:09 [engine.py:400]     return self.processor(info, dummy_inputs_builder, cache=cache)
ERROR 03-07 14:56:09 [engine.py:400]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400]   File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_vl2.py", line 234, in __init__
ERROR 03-07 14:56:09 [engine.py:400]     if self.cache is not None and mm_limit["image"] > 2:
ERROR 03-07 14:56:09 [engine.py:400]                                   ~~~~~~~~^^^^^^^^^
ERROR 03-07 14:56:09 [engine.py:400] KeyError: 'image'
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 402, in run_mp_engine
    raise e
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 391, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 124, in from_engine_args
    return cls(ipc_path=ipc_path,
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 76, in __init__
    self.engine = LLMEngine(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 276, in __init__
    self._initialize_kv_caches()
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 421, in _initialize_kv_caches
    self.model_executor.determine_num_available_blocks())
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 102, in determine_num_available_blocks
    results = self.collective_rpc("determine_num_available_blocks")
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/utils.py", line 2232, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1238, in profile_run
    self._dummy_run(max_num_batched_tokens, max_num_seqs)
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1282, in _dummy_run
    max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 324, in get_max_multimodal_tokens
    return sum(self.get_max_tokens_by_modality(model_config).values())
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 306, in get_max_tokens_by_modality
    mm_limits = self.get_mm_limits_per_prompt(model_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 375, in get_mm_limits_per_prompt
    processor = self.create_processor(model_config, tokenizer)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 450, in create_processor
    return factories.build_processor(ctx, cache=cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/multimodal/registry.py", line 83, in build_processor
    return self.processor(info, dummy_inputs_builder, cache=cache)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/model_executor/models/deepseek_vl2.py", line 234, in __init__
    if self.cache is not None and mm_limit["image"] > 2:
                                  ~~~~~~~~^^^^^^^^^
KeyError: 'image'
[rank0]:[W307 14:56:09.564557541 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
Traceback (most recent call last):
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/bin/vllm", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/cli/main.py", line 73, in main
    args.dispatch_function(args)
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/cli/serve.py", line 34, in cmd
    uvloop.run(run_server(args))
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 946, in run_server
    async with build_async_engine_client(args) as engine_client:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 138, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/share/home/defaultTenant/longyb1/.conda/envs/vllm/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 232, in build_async_engine_client_from_engine_args
    raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.

@DarkLight1337
Member

Thanks for reporting! It should be fixed by #14417
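
For reference, a defensive rewrite of the failing check avoids the KeyError when no "image" entry is present in the limits dict. This is a sketch only; whether #14417 fixes it this way is not confirmed here:

```python
def should_disable_cache(mm_limit: dict) -> bool:
    # The failing line indexed mm_limit["image"] directly; using .get()
    # with a default of 0 keeps the cache enabled when no image limit is
    # configured, instead of crashing during profiling.
    return mm_limit.get("image", 0) > 2

assert should_disable_cache({"image": 4}) is True
assert should_disable_cache({}) is False
```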

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Successfully merging this pull request may close these issues.

[Bug]: ValueError: Attempted to assign 421 + 421 + 421 + 421 + 421 + 421 = 2526 multimodal tokens to 12606 placeholders