
Conversation

@davidxia
Contributor

@davidxia davidxia commented May 12, 2025

by changing some modules in vllm/multimodal to lazily import expensive modules like transformers, or to import them only for type checkers when they aren't used at runtime.

contributes to #14924
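
The general shape of the change is roughly the following (a minimal sketch, not the exact vLLM code; `to_batch_feature` is a hypothetical helper and `transformers.BatchFeature` just stands in for any heavy import):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to type checkers only; never imported at runtime.
    from transformers import BatchFeature


def to_batch_feature(data: dict) -> "BatchFeature":
    # Deferred import: the cost is paid on first call,
    # not during `import vllm`.
    from transformers import BatchFeature

    return BatchFeature(data)
```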

python -c 'import vllm'

seems slightly faster

before (main branch commit 302f3ac)

$ hyperfine 'python -c "import vllm"' --warmup 3 --runs 100 --export-markdown out.md
Benchmark 1: python -c "import vllm"
  Time (mean ± σ):      9.367 s ±  0.215 s    [User: 8.951 s, System: 2.028 s]
  Range (min … max):    8.931 s … 10.013 s    100 runs

| Command                    | Mean [s]      | Min [s] | Max [s] | Relative |
|----------------------------|---------------|---------|---------|----------|
| `python -c "import vllm"`  | 9.367 ± 0.215 | 8.931   | 10.013  | 1.00     |

after (my PR commit de28f4f933760b7b53aca164ac8c2d7b5256bf11)

$ hyperfine 'python -c "import vllm"' --warmup 3 --runs 100 --export-markdown out.md
Benchmark 1: python -c "import vllm"
  Time (mean ± σ):      9.196 s ±  0.373 s    [User: 8.758 s, System: 2.065 s]
  Range (min … max):    8.837 s … 12.306 s    100 runs

| Command                    | Mean [s]      | Min [s] | Max [s] | Relative |
|----------------------------|---------------|---------|---------|----------|
| `python -c "import vllm"`  | 9.196 ± 0.373 | 8.837   | 12.306  | 1.00     |

@davidxia davidxia marked this pull request as ready for review May 12, 2025 23:15
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of that by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the multi-modality (Related to multi-modality, #4194) label May 12, 2025
Collaborator

@aarnphm aarnphm left a comment


Overall looks good, but can you test whether multimodal requests still work?

iirc there are some files like config.py where we have to import eagerly (probably fine with this PR, but still worth a quick check)

@aarnphm
Collaborator

aarnphm commented May 13, 2025

not in the scope of this PR, but ideally we want to reduce this 8s as much as possible with lazy loading

@davidxia
Contributor Author

not in the scope of this PR, but ideally we want to reduce this 8s as much as possible with lazy loading

what's "8s"?

can you test whether multimodal requests still work?

Do you have an example I should run?

@aarnphm
Collaborator

aarnphm commented May 13, 2025

what's "8s"?

from your hyperfine run, especially User: 8.758 s. This is just a reference note, that's all.

Do you have an example I should run?

There are a few examples in vllm/examples here

@davidxia
Contributor Author

davidxia commented May 13, 2025

from your hyperfine run, especially User: 8.758 s. This is just a reference note, that's all.

ah that's right 😅

There are a few examples in vllm/examples here

thanks, I tried that earlier on a CPU platform with vllm serve llava-hf/llava-1.5-7b-hf. But running that file caused the server to crash after sending a 500 response. Do you know if that example only works on GPU platforms and isn't supported on CPU?
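
For reference, the kind of multimodal request I was testing looks roughly like this (a sketch assuming the server from `vllm serve llava-hf/llava-1.5-7b-hf` is listening on localhost:8000; the image URL is just a placeholder):

```python
from openai import OpenAI

# Assumes the OpenAI-compatible server is running on localhost:8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            # Placeholder URL; any reachable image works here.
            {"type": "image_url",
             "image_url": {"url": "https://example.com/duck.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)
```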

@aarnphm
Collaborator

aarnphm commented May 13, 2025

Ah, let me perform a quick test then if you don't have access to a GPU

@aarnphm
Collaborator

aarnphm commented May 13, 2025

Ah, let me perform a quick test then if you don't have access to a GPU

This works with phi 3.5 vision. You can use the diff here

diff.patch

@davidxia davidxia changed the title from [Frontend] reduce vLLM's import time to [Frontend] decrease import time of vllm.multimodal May 13, 2025
@davidxia
Contributor Author

Ah, let me perform a quick test then if you don't have access to a GPU

This works with phi 3.5 vision. You can use the diff here

diff.patch

thanks, applied your patch

@davidxia davidxia force-pushed the patch3 branch 2 times, most recently from 6b526d3 to dfdb8c6 Compare May 13, 2025 12:12
@russellb russellb requested a review from Copilot May 13, 2025 12:22
Contributor

Copilot AI left a comment


Pull Request Overview

This PR decreases the import time of vllm.multimodal by lazily loading expensive modules and deferring certain imports to type-checking or local scopes.

  • Relocate transformers imports from top-level to type-checking blocks or local function scopes.
  • Adjust type annotations for improved runtime performance and maintain consistency across modules.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
|------|-------------|
| vllm/multimodal/processing.py | Moves heavy transformers imports to type-checking and removes redundant quotes. |
| vllm/multimodal/parse.py | Shifts direct PIL.Image and BatchFeature imports to local scopes in functions. |
| vllm/multimodal/inputs.py | Implements LazyLoader for torch and refines type aliases and annotations. |
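
For readers unfamiliar with the pattern, a LazyLoader along these lines defers the real import until the first attribute access. This is a minimal sketch of the idea, not vLLM's actual helper, which may differ in detail:

```python
import importlib
import types


class LazyLoader(types.ModuleType):
    """Defers importing a module until one of its attributes is accessed."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        self._local_name = local_name
        self._parent_globals = parent_globals
        super().__init__(name)

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self.__name__)
        # Swap the lazy placeholder for the real module in the caller's
        # namespace so later lookups skip this wrapper entirely.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)


# torch is not actually imported until e.g. torch.Tensor is first touched.
torch = LazyLoader("torch", globals(), "torch")
```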

@davidxia davidxia force-pushed the patch3 branch 2 times, most recently from 904a9d4 to 05b1cbf Compare May 13, 2025 12:27
Member

@russellb russellb left a comment


Thanks for taking the piecewise approach. This will be easier to review and merge.

@russellb russellb enabled auto-merge (squash) May 13, 2025 12:31
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label May 13, 2025
auto-merge was automatically disabled May 13, 2025 13:02

Head branch was pushed to by a user without write access

@davidxia
Contributor Author

davidxia commented May 13, 2025

Thanks for taking the piecewise approach. This will be easier to review and merge.

Of course! I lazy-loaded PIL.Image in vllm/multimodal/parse.py. Ready for another review. See the diff

@aarnphm
Collaborator

aarnphm commented May 13, 2025

@davidxia could you apply this patch: diff.patch
Thanks, I don't have permission to push to your repo

auto-merge was automatically disabled May 13, 2025 14:32

Head branch was pushed to by a user without write access

@davidxia
Contributor Author

@davidxia could you apply this patch: diff.patch Thanks, I don't have permission to push to your repo

done, thanks!

@aarnphm aarnphm modified the milestone: v0.9.0 May 13, 2025
@aarnphm
Collaborator

aarnphm commented May 13, 2025

@hmellor from the readthedocs logs it seems to build successfully? Do you know if there are any issues with this?

@hmellor
Member

hmellor commented May 13, 2025

@hmellor from the readthedocs logs it seems to build successfully? Do you know if there are any issues with this?

RTD treats warnings as errors:

/home/docs/checkouts/readthedocs.org/user_builds/vllm/checkouts/18031/docs/source/api/summary.md:88: WARNING: Could not find vllm.multimodal.inputs.MultiModalDataDict [autodoc2.missing]
/home/docs/checkouts/readthedocs.org/user_builds/vllm/checkouts/18031/docs/source/api/summary.md:95: WARNING: Could not find vllm.multimodal.inputs.NestedTensors [autodoc2.missing]

@aarnphm
Collaborator

aarnphm commented May 13, 2025

ah I see.

@davidxia can you move the docstring from the TYPE_CHECKING block down to the else block instead? Thanks.

@hmellor
Member

hmellor commented May 13, 2025

You don't necessarily have to move it, but something with that name has to exist in the else block.
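
Concretely, that would look something like this (a hedged sketch of the suggestion; the alias body here is simplified and is not vLLM's actual definition of NestedTensors):

```python
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    import torch

    # The precise alias that type checkers see.
    NestedTensors = Union["torch.Tensor", list["torch.Tensor"]]
else:
    # Something with the same name must exist at runtime so that
    # autodoc2 can resolve vllm.multimodal.inputs.NestedTensors.
    NestedTensors = Any
    """Docstring goes here, attached to the runtime name."""
```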

@davidxia
Contributor Author

@aarnphm @hmellor I'm trying to fix the Sphinx warnings. I tried copying the same MultiModalDataDict docstring from the if TYPE_CHECKING block into the else block. But when I run python -m sphinx -T -W --keep-going -b html -d _build/doctrees -D language=en . out/html locally, I still see the same warnings. Any ideas?

@aarnphm
Collaborator

aarnphm commented May 14, 2025

MultiModalDataDict docstring from the if TYPE_CHECKING block into the else block. But when I run python -m sphinx -T -W --keep-going -b html -d _build/doctrees -D language=en . out/html locally, I still see the same warnings. Any ideas?

probably better to keep the previous change, but update the annotations to strings instead

i.e.:

if TYPE_CHECKING:
    import torch

HfImageItem: TypeAlias = Union[Image, np.ndarray, "torch.Tensor"]
"""docstring as before"""

Collaborator

@aarnphm aarnphm left a comment


s/torch.Tensor/"torch.Tensor"

@davidxia davidxia force-pushed the patch3 branch 2 times, most recently from 070c913 to 9371c2c Compare May 14, 2025 15:44
by changing some modules in `vllm/multimodal` to lazily import expensive
modules like `transformers` or only importing them for type checkers when not
used during runtime.

contributes to vllm-project#14924

Signed-off-by: David Xia <david@davidxia.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>

Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
Co-authored-by: David Xia <david@davidxia.com>
@russellb russellb enabled auto-merge (squash) May 14, 2025 17:45
@simon-mo simon-mo disabled auto-merge May 14, 2025 22:43
@simon-mo simon-mo merged commit 749f792 into vllm-project:main May 14, 2025
57 of 59 checks passed
@davidxia davidxia deleted the patch3 branch May 14, 2025 23:08
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Co-authored-by: Aaron Pham <Aaronpham0103@gmail.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>