
Conversation

@lengrongfu (Contributor) commented May 14, 2025

Issue: #15697

This PR migrates the following models to AutoWeightsLoader. Each was verified with the serve command listed below it; a sketch of the resulting load_weights pattern follows the list.

  • olmo.py
python3 -m vllm.entrypoints.cli.main serve allenai/OLMo-1B-hf --trust-remote-code --gpu-memory-utilization 0.95 
  • olmo2.py
python3 -m vllm.entrypoints.cli.main serve allenai/OLMo-2-0425-1B --trust-remote-code --gpu-memory-utilization 0.95 
  • mixtral_quant.py
python3 -m vllm.entrypoints.cli.main serve TitanML/tiny-mixtral --trust-remote-code --gpu-memory-utilization 0.95
  • solar.py
python3 -m vllm.entrypoints.cli.main serve upstage/solar-pro-preview-instruct --trust-remote-code --gpu-memory-utilization 0.95
  • nemotron.py
python3 -m vllm.entrypoints.cli.main serve nvidia/Minitron-4B-Base --trust-remote-code --gpu-memory-utilization 0.95
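A minimal sketch of the pattern these models move to, assuming vLLM's AutoWeightsLoader from vllm.model_executor.models.utils; the skip_prefixes value is a hypothetical example, not the exact per-model logic merged here:

    from collections.abc import Iterable

    import torch

    from vllm.model_executor.models.utils import AutoWeightsLoader

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Delegate per-parameter handling (stacked params, fused QKV, etc.)
        # to AutoWeightsLoader instead of a hand-written loop.
        loader = AutoWeightsLoader(
            self,
            # Hypothetical: skip lm_head.* when word embeddings are tied.
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)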

@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of it by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@lengrongfu lengrongfu force-pushed the feat/new1-use-autoweights branch 2 times, most recently from f1c836e to d7e6366 Compare May 14, 2025 16:56
@mergify (bot) commented May 15, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @lengrongfu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 15, 2025
@lengrongfu lengrongfu force-pushed the feat/new1-use-autoweights branch from d7e6366 to b9c79ad Compare May 15, 2025 10:44
@mergify mergify bot removed the needs-rebase label May 15, 2025
@lengrongfu lengrongfu marked this pull request as ready for review May 15, 2025 10:46
Comment on lines +411 to +407
Member:

Although this won't affect much, I think rotary_emb.inv_freq, rotary_emb.cos_cached, and rotary_emb.sin_cached are not prefixes.
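A minimal illustration of the distinction, assuming skip_prefixes matches from the start of the full weight name:

    # These buffer names appear midway through the full parameter name,
    # so a prefix match never fires, while a substring match does.
    name = "model.layers.0.self_attn.rotary_emb.inv_freq"
    print(name.startswith("rotary_emb.inv_freq"))  # False: not a prefix
    print("rotary_emb.inv_freq" in name)           # True: a substring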

Contributor Author:

So what would you suggest instead?

Member:

We can filter out the weights to skip before calling AutoWeightsLoader, just like Phi-4-MM does:

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> None:
        weights = ((name, data) for name, data in weights
                   if "lora" not in name)
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
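A hedged adaptation of that idea to the RoPE buffers discussed in this thread (illustrative only, not the code merged in this PR):

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Drop the RoPE buffer tensors (matched as substrings, not prefixes)
        # before delegating the rest of the stream to AutoWeightsLoader.
        rope_buffers = ("rotary_emb.inv_freq", "rotary_emb.cos_cached",
                        "rotary_emb.sin_cached")
        weights = ((name, data) for name, data in weights
                   if not any(buf in name for buf in rope_buffers))
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)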

Contributor Author:

Does the situation you're describing actually occur? In my testing, loading works properly with the prefix approach.

@Isotr0py (Member) commented May 16, 2025:

Yeah, this situation seldom occurs: only some model checkpoints fine-tuned with ColossalAI include these tensors, and the models you tested likely don't have them.

['model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.22.self_attn.rotary_emb.inv_freq', 'model.layers.24.self_attn.rotary_emb.inv_freq', ... ]

But it's still reasonable to make the weight loading robust enough to cover this rare case.

@Isotr0py (Member) left a review:

Anyway, since the RoPE buffer tensors in the checkpoint aren't a bug introduced here, let's merge this PR and handle this case, together with the other models' modified loading logic, in a follow-up PR.

@Isotr0py Isotr0py added the ready ONLY add when PR is ready to merge/full CI is needed label May 16, 2025
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
@lengrongfu lengrongfu force-pushed the feat/new1-use-autoweights branch from b9c79ad to 922302f Compare May 16, 2025 16:01
@Isotr0py Isotr0py enabled auto-merge (squash) May 16, 2025 16:06
@simon-mo simon-mo merged commit 9214e60 into vllm-project:main May 17, 2025
59 of 62 checks passed
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
