Skip to content

[V1][Spec Decode] KV cache slots for eagle heads #16370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 13, 2025

Conversation

LiuXiaoxuanPKU
Copy link
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU commented Apr 10, 2025

Task 2 of #15901

Current change only touches the kv cache manager, the scheduler only changes its way of calling the allocate_slots.

I have not tested this PR yet, but I feel a bit hard to test, comments are appreciated. cc@WoosukKwon

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Apr 10, 2025
Comment on lines 198 to 199
num_new_tokens,
num_spec_tokens=self.num_spec_tokens)
Copy link
Contributor

@ekagra-ranjan ekagra-ranjan Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that num_new_tokens is needed for the verfication of the spec token ids by the target model from previous step and num_spec_tokens is the num of spec tokens that the draft model is supposed to generate at the end of this step.

Based on that, if num_new_tokens is 8 and num_spec_tokens is 4 so can end up allocating 1 block (16 tokens) such that 1 block shares both target model and draft model's KV cache?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is similar. My interpretation is that it temporarily acquires extra num_spec_tokens for draft tokens and it won't aggregate the size in the next iteration.

Copy link
Contributor

@ekagra-ranjan ekagra-ranjan Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the same blocks are shared by target and draft models then will it not be an issue since the KVC of target and draft model are adjacent in the logical mapping of block tables so draft model will attend to KVC of the target?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Emmm, I think it should not cause a problem as long as the actual starting kv cache slot for draft model is marked somehow?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the discussion here!

  1. num_new_tokens is for verification, num_spec_tokens is for proposing heads. "Based on that, if num_new_tokens is 8 and num_spec_tokens is 4 so can end up allocating 1 block (16 tokens) such that 1 block shares both target model and draft model's KV cache?" --> yes, exactly.
  2. KV cache corruption: currently, kv cache is allocated independently but share the same slot mapping. We can think of kv cache as a map: {layer0_kv: [], layer1_kv: []...., layerk_kv, eagle_layer_kv:[]}. During each generation step, using the example above, we first verify tokens, which will write kv to layer0_kv...layerk_kv with slot mapping [0,1,2,3,4,5,6,7]. It will not write to the draft kv. If say only 2 tokens are accepted, 3 tokens are generated. In the proposing phase, we will send the three tokens to eagle proposer with slot mapping [0,1,2], which will populate the kv cache for the generated tokens, and also propose for the next token. We allocate 12 slots (8+4) in total, because it's possible that all tokens (with slot id 0-7) are accepted, in that case, proposing tokens need to write to kv cache with id [8,9,10,11].

Let me know if there is any confusion here!

Copy link
Contributor

@ekagra-ranjan ekagra-ranjan Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense now. The block_table where the allocated slots get saved are shared across all layers and eagle is just a layer on top of the target model's layer. When we add blocks for num_new_tokens + num_spec_tokens then the target model will use just the num_new_tokens slots but in the case when all the drafts are accepted, draft layer will use the num_new_tokens + num_spec_tokens slots.

@@ -164,7 +164,8 @@ def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_spec_tokens: int = 0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have two points to discuss:

  1. Should we use "num_lookahead_tokens" to reduce confusion? After all these slots are for the proposed tokens that will be verified in the next step.
  2. Should we consider these slots along with the preallocated blocks? Specially if preallocated blocks can cover spec tokens, then we don't need to allocate additional slots?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have seen this term lookahead_tokens before. Can you share why this is more general than spec_tokens? Is it because it can also mean jump tokens?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No jump tokens should be in new_tokens. I just feel num_spec_tokens is confusing because it actually means the spec tokens we're going to propose by the end of this step. However, we also have spec_tokens in Request, but that spec_tokens were generated by the last step for verification.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to @comaniac I have the same two questions, too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I am good with num_lookahead_tokens, will change here.
  2. yeah sure, we can do a more conservation way,
    preallocated_blocks -= num_lookahead_tokens // block_size

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

preallocated_blocks -= num_lookahead_tokens // block_size

We might have to revert this when num of draft tokens become large espc with tree attn since then num draft tokens ~= num preallocated tokens which would lead to frequent block allocations.

Comment on lines 218 to 220
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_spec_tokens,
self.block_size)
Copy link
Contributor

@ekagra-ranjan ekagra-ranjan Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luyuzhe111 @wwl2755 - Moving the discussion of why this PR is expected to improve the AL.

I have a hypothesis. Without this PR, the queries in the draft can go out of bounds in the block_table and pick up incorrect address and value which will corrupt the answer. block_table is used in FA cuda kernels and maybe we dont check illegal memory address access.

Lets say page size is 16. This corruption will arise when have < K slots left in the last block. The preallocate block computation (extra 4 blocks) wont trigger in this case since the last block is not full. As K increases, the changes of this increases. So K=4 has higher chances of having this than K=2 which reflects here.

But then block_table is gathered here too to form the slot_mapping for queries so out of index should have given an error which it did not when using bs=1 with MTBench so I am not sure if above hypothesis is correct.

Lmk what you guys think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon @LiuXiaoxuanPKU - can you also share your insight as to why this PR is expected to increase AL?

Copy link
Contributor

@wwl2755 wwl2755 Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Is the statement "this PR can increase AL" already benchmarked OR is it set up as a goal of this PR?

Copy link
Collaborator Author

@LiuXiaoxuanPKU LiuXiaoxuanPKU Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a high level, if we don't have this PR, the current scheduler does not actually allocate slots for proposed tokens, they only allocate slots for verification. Therefore, it's not guaranteed the kv cache of the proposed heads is not contaminated.

Copy link
Contributor

@ekagra-ranjan ekagra-ranjan Apr 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@LiuXiaoxuanPKU can you help us understand at a bit deeper level like which code line would be at fault?

My understanding is that If the scheduler doesn't allocate slots for the proposed tokens then torch should have thrown some error here when the new proposed tokens become the query? However, it didnt happen in our MTBench benchmark so probably there is no corruption without this PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for asking! here will not trigger an error because block_table is always of a tensor of shape [batch_size, max_num_blocks_per_request], if those blocks are not allocated, the default values will be 0 in the block table.

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
@LiuXiaoxuanPKU LiuXiaoxuanPKU added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 12, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Copy link
Collaborator

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM. Also it would be good to have a simple unit test to evaluate the allocated slots.

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
num_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks) == 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems a little bit weird to the common sense. When the num_lookahead_tokens increases (which means we may need more slots for allocation), the allocated blocks decrease.

In the test case 2, the num_lookahead_tokens does not use the slots for preallocate_tokens. Is there any particular reason why the design will be like this?

Copy link
Collaborator Author

@LiuXiaoxuanPKU LiuXiaoxuanPKU Apr 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah agree, I feel it's a conner case, num_lookahead_tokens does not use slots for preallocate_tokens because we calculate on the block level, and 3 // 4 = 0 (lookahead tokens borrow 0 block from preallocate tokens).

when num_lookahead_tokens=1, 2, 3, len(blocks) = 3
when num_lookahead_tokens=4, len(blocks) = 2
when num_lookahead_tokens=5, 6, 7...., len(blocks) = ceil((num_lookahead_tokens + 4 ) / 4) = 3 or bigger

number of required blocks = num of blocks required by lookahead slots + num of blocks required by compute tokens + num of preallocate blocks
num of preallocate blocks = max(0, Constant - num of blocks required by lookahead slots)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think with num_lookahead_tokens=1,2,3, len(blocks) should stll be 2?

IIUC, the num_lookahead_token can used up the space for preallocate_tokens. So, basically before actually pre-allocate, we just need to check sure the last preallocate_blocks are already taken up by lookahead_tokens

May be the pseudocode could be something like:

num_required_blocks_before_lookahead = cdiv(
            num_computed_tokens + num_tokens,
            self.block_size)
num_required_blocks = cdiv(
            num_computed_tokens + num_tokens + num_lookahead_tokens,
            self.block_size)
num_required_blocks_used_by_lookahead = num_required_blocks - 
             num_required_blocks_before_lookahead

num_preallocate_blocks = max(
                0, self.num_preallocate_blocks -
                num_required_blocks_used_by_lookahead)
# if we find the preallocated blocks have been used up by lookahead, we don't need to further allocate them.

image

@LiuXiaoxuanPKU LiuXiaoxuanPKU merged commit f49e5af into vllm-project:main Apr 13, 2025
42 checks passed
@ekagra-ranjan
Copy link
Contributor

ekagra-ranjan commented Apr 14, 2025

I benchmarked for AL using the same setup (NOTE: yuhuili/EAGLE-LLaMA3-Instruct-8B and lmsys/sglang-EAGLE-LLaMA3-Instruct-8B have identical AL) . Combining it with the results for EAGLE official repo by @luyuzhe111 , we get this:

K=2:

  • without this PR: 1.89
  • with this PR: 1.90
  • EAGLE official repo: 2.0

K=4:

  • without this PR: 2.08
  • with this PR: 2.09
  • EAGLE official repo: 2.25

It was expected that this PR would close the gap bw vLLM and official EAGLE AL but it seems the gap is still there. Please share your thoughts on this. cc: @LiuXiaoxuanPKU @luyuzhe111 @wwl2755 @WoosukKwon

erdaltoprak pushed a commit to erdaltoprak/vllm that referenced this pull request Apr 14, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Erdal Toprak <contact@erdaltoprak.com>
@wwl2755
Copy link
Contributor

wwl2755 commented Apr 16, 2025

It was expected that this PR would close the gap bw vLLM and official EAGLE AL but it seems the gap is still there. Please share your thoughts on this. cc: @LiuXiaoxuanPKU @luyuzhe111 @wwl2755 @WoosukKwon

Then my suspection is there are some implementation flaws? I will take a closer look starting from the existing tests, including the proposing, sampling, and rejection.

@luyuzhe111
Copy link
Contributor

Hey @ekagra-ranjan @wwl2755,

I actually ran some the same AL benchmark the other day for meta-llama/Llama-3.1-8B-Instruct and yuhuili/EAGLE-LLaMA3.1-Instruct-8B instead of Llama 3.0 8B. The results are shown below. Basically, the gap between the EAGLE repo and vLLM v1 is actually small now. I haven't figured out why the gap is more pronounced for Llama 3.0 8B models though.

On MT Bench

When max number generated tokens = 256

Number of Speculated Tokens 1 2 3 4 5
EAGLE Repo 1.71 2.13 2.32 2.42 2.48
vLLM v0 1.70 2.06 2.20 2.29 2.32
vLLM v1 1.71 2.09 2.30 2.38 2.43

Chenyaaang pushed a commit to Chenyaaang/vllm that referenced this pull request Apr 16, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
@wwl2755
Copy link
Contributor

wwl2755 commented Apr 17, 2025

One thing I observed when I ran the benchmark in this PR (#16367) locally was that the results were not consistent between different runs. Sometimes K=4 gives 2.09, sometimes gives 2.10. Everything is default, temperature is 0, mt_bench is used, max_tokens is 256.

K=2:

  • without this PR: 1.89
  • with this PR: 1.90
  • EAGLE official repo: 2.0

K=4:

  • without this PR: 2.08
  • with this PR: 2.09
  • EAGLE official repo: 2.25

So, I'm thinking this 0.01 difference may not be led by this PR. That's saying, there should be gaps somewhere else.

Example command: VLLM_USE_V1=1 python examples/offline_inference/eagle.py --dataset /home/cc/vllm_benchmark_datatsets/question.jsonl --num_spec_tokens 4

cc: @ekagra-ranjan @luyuzhe111

yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
sfc-gh-mhidayetoglu added a commit to sfc-gh-mhidayetoglu/vllm that referenced this pull request Apr 29, 2025
* [Docs] Add Ollama meetup slides (vllm-project#15905)

Signed-off-by: simon-mo <simon.mo@hey.com>

* [Docs] Add Intel as Sponsor (vllm-project#15913)

Signed-off-by: simon-mo <simon.mo@hey.com>

* [Spec Decode] Fix input triton kernel for eagle (vllm-project#15909)

* [V1] Fix: make sure `k_index` is int64 for `apply_top_k_only` (vllm-project#15907)

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* [Bugfix] Fix imports for MoE on CPU (vllm-project#15841)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>

* [V1][Minor] Enhance SpecDecoding Metrics Log in V1 (vllm-project#15902)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Doc] Update rocm.inc.md (vllm-project#15917)

Signed-off-by: chun37 <chun.jb.37@gmail.com>

* [V1][Bugfix] Fix typo in MoE TPU checking (vllm-project#15927)

Signed-off-by: Roger Wang <ywang@roblox.com>

* [Benchmark]Fix error message (vllm-project#15866)

Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>

* [Misc] Replace print with logger (vllm-project#15923)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [CI/Build] Further clean up LoRA tests (vllm-project#15920)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* [Bugfix] Fix cache block size calculation for CPU MLA (vllm-project#15848)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>

* [Build/CI] Update lm-eval to 0.4.8 (vllm-project#15912)

Signed-off-by: Chris Thi <chris.c.thi@gmail.com>

* [Kernel] Add more dtype support for GGUF dequantization (vllm-project#15879)

Signed-off-by: lukas.bluebaum <lukas.bluebaum@aleph-alpha.com>

* [core] Add tags parameter to wake_up() (vllm-project#15500)

Signed-off-by: Eric <erictang000@gmail.com>

* [V1] Fix json_object support with xgrammar (vllm-project#15488)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* Add minimum version for `huggingface_hub` to enable Xet downloads (vllm-project#15873)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Bugfix][Benchmarks] Ensure `async_request_deepspeed_mii` uses the OpenAI choices key (vllm-project#15926)

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* [CI] Remove duplicate entrypoints-test (vllm-project#15940)

Signed-off-by: Kay Yan <kay.yan@daocloud.io>

* [Bugfix] Fix the issue where the model name is empty string, causing no response with the model name. (vllm-project#15938)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [Metrics] Hide deprecated metrics (vllm-project#15458)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

* [Frontend] Implement Tool Calling with `tool_choice='required'` (vllm-project#13483)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>

* [CPU][Bugfix] Using custom allreduce for CPU backend (vllm-project#15934)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

* [Model] use AutoWeightsLoader in model load_weights (vllm-project#15770)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Misc] V1 LoRA support CPU offload (vllm-project#15843)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* Restricted cmake to be less than version 4 as 4.x breaks the build of… (vllm-project#15859)

Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com>

* [misc] instruct pytorch to use nvml-based cuda check (vllm-project#15951)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [V1] Support Mistral3 in V1 (vllm-project#15950)

Signed-off-by: mgoin <mgoin64@gmail.com>

* Fix `huggingface-cli[hf-xet]` -> `huggingface-cli[hf_xet]` (vllm-project#15969)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [V1][TPU] TPU-optimized top-p implementation (avoids scattering). (vllm-project#15736)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: root <root@t1v-n-822696b7-w-0.us-central2-b.c.tpu-prod-env-large-adhoc.internal>

* [TPU] optimize the all-reduce performance (vllm-project#15903)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [V1][TPU] Do not compile sampling more than needed (vllm-project#15883)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [ROCM][KERNEL] Paged attention for V1 (vllm-project#15720)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>

* fix: better error message for get_config close vllm-project#13889 (vllm-project#15943)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [bugfix] add seed in torchrun_example.py (vllm-project#15980)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [ROCM][V0] PA kennel selection when no sliding window provided (vllm-project#15982)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>

* [Benchmark] Add AIMO Dataset to Benchmark (vllm-project#15955)

Signed-off-by: Ziji Shi <shi.ziji.sm@gmail.com>
Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>

* [misc] improve error message for "Failed to infer device type" (vllm-project#15994)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [Bugfix][V1] Fix bug from putting llm_engine.model_executor in a background process (vllm-project#15367)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>

* [doc] update contribution link (vllm-project#15922)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* fix: tiny fix make format.sh excutable (vllm-project#16015)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [SupportsQuant] Bert, Blip, Blip2, Bloom (vllm-project#15573)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* [SupportsQuant] Chameleon, Chatglm, Commandr (vllm-project#15952)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* [Neuron][kernel] Fuse kv cache into a single tensor (vllm-project#15911)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>

* [Minor] Fused experts refactor (vllm-project#15914)

Signed-off-by: Bill Nell <bnell@redhat.com>

* [Misc][Performance] Advance tpu.txt to the most recent nightly torch … (vllm-project#16024)

* Re-enable the AMD Testing for the passing tests. (vllm-project#15586)

Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>

* [TPU] Support sliding window and logit soft capping in the paged attention kernel for TPU. (vllm-project#15732)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>

* [TPU] Switch Test to Non-Sliding Window (vllm-project#15981)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>

* [Bugfix] Fix function names in test_block_fp8.py (vllm-project#16033)

Signed-off-by: Bill Nell <bnell@redhat.com>

* [ROCm] Tweak the benchmark script to run on ROCm (vllm-project#14252)

* [Misc] improve gguf check (vllm-project#15974)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [TPU][V1] Remove ragged attention kernel parameter hard coding (vllm-project#16041)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* doc: add info for macos clang errors (vllm-project#16049)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [V1][Spec Decode] Avoid logging useless nan metrics (vllm-project#16023)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

* [Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (vllm-project#15939)

Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>

* [Hardware][Gaudi][BugFix] fix arguments of hpu fused moe (vllm-project#15945)

Signed-off-by: zhenwei <zhenweiliu@habana.ai>

* [Bugfix][kernels] Fix half2float conversion in gguf kernels (vllm-project#15995)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [Benchmark][Doc] Update throughput benchmark and README (vllm-project#15998)

Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

* [CPU] Change default block_size for CPU backend (vllm-project#16002)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

* [Distributed] [ROCM] Fix custom allreduce enable checks (vllm-project#16010)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>

* [ROCm][Bugfix] Use platform specific FP8 dtype (vllm-project#15717)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [ROCm][Bugfix] Bring back fallback to eager mode removed in vllm-project#14917, but for ROCm only (vllm-project#15413)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [Bugfix] Fix default behavior/fallback for pp in v1 (vllm-project#16057)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [CI] Reorganize .buildkite directory (vllm-project#16001)

Signed-off-by: kevin <kevin@anyscale.com>

* [V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (vllm-project#15906)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [V1] Scatter and gather placeholders in the model runner (vllm-project#15712)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

* Revert "[V1] Scatter and gather placeholders in the model runner" (vllm-project#16075)

* [Kernel][Minor] Re-fuse triton moe weight application (vllm-project#16071)

Signed-off-by: Bill Nell <bnell@redhat.com>

* [Bugfix][TPU] Fix V1 TPU worker for sliding window (vllm-project#16059)

Signed-off-by: Michael Goin <mgoin64@gmail.com>

* [V1][Spec Decode] Update N-gram Proposer Interface (vllm-project#15750)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Misc] Auto detect bitsandbytes pre-quantized models (vllm-project#16027)

Signed-off-by: Tristan Leclercq <tristanleclercq@gmail.com>

* [CI] Fix benchmark script level (vllm-project#16089)

* fix: support clang17 for macos and fix the real libomp (vllm-project#16086)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [doc] fix 404 (vllm-project#16082)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Revert "doc: add info for macos clang errors (vllm-project#16049)" (vllm-project#16091)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* Fix some capitalisations in generated examples doc titles (vllm-project#16094)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Misc] format output for encoder_decoder.py (vllm-project#16095)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Misc] Remove redundant code (vllm-project#16098)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [Bugfix] fix use_atomic_add support of marlin kernel when using v1 engine (vllm-project#15946)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>

* [Model] use AutoWeightsLoader for phi, gemma, deepseek (vllm-project#16088)

Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>

* [Model] fix model testing for TeleChat2ForCausalLM and V0 llama4 (vllm-project#16112)

Signed-off-by: Lu Fang <fanglu@fb.com>

* [Benchmark] Add sampling parameters to benchmark_serving. (vllm-project#16022)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>

* [Frontend] Fix typo in tool chat templates for llama3.2 and toolace (vllm-project#14501)

Signed-off-by: Ben Jackson <ben@ben.com>

* [CI][V1] Fix passing `tokenizer` as kwarg to `validate_guidance_grammar` (vllm-project#16117)

Signed-off-by: Roger Wang <ywang@roblox.com>

* [Misc] refactor example eagle (vllm-project#16100)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Doc][Bugfix] Add missing EOF in k8s deploy doc (vllm-project#16025)

* [Misc] Improve model redirect to accept json dictionary (vllm-project#16119)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2 (vllm-project#16103)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Bugfix] LoRA : Fix the order in which the kernels process LoRAs  (vllm-project#16040)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

* [Bugfix] add hf_token to EngineArgs (vllm-project#16093)

Signed-off-by: paolovic <paul-philipp.luley@uzh.ch>
Co-authored-by: paolovic <paul-philipp.luley@uzh.ch>

* [Misc] update requires-python in pyproject.toml (vllm-project#16116)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [TPU] Update PyTorch/XLA (vllm-project#16130)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [V1][Minor] Optimize get_cached_block (vllm-project#16135)

* Fix requires-python (vllm-project#16132)

* [Metrics] Add bucket for `request_latency`, `time_to_first_token` and `time_per_output_token` (vllm-project#15202)

Signed-off-by: Kay Yan <kay.yan@daocloud.io>

* [V1][Minor] Minor simplification for get_computed_blocks  (vllm-project#16139)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Misc] Update Mistral-3.1 example (vllm-project#16147)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Make dummy encoder prompt padding alternative and add missing warnings (vllm-project#16129)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [CI] Set max transformers version for Ultravox model test  (vllm-project#16149)

Signed-off-by: Roger Wang <ywang@roblox.com>

* doc: fix some typos in doc (vllm-project#16154)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [VLM] Florence-2 supports online serving (vllm-project#16164)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [V1][Structured Output] Add `supports_structured_output()` method to Platform (vllm-project#16148)

Signed-off-by: shen-shanshan <467638484@qq.com>

* [Model] Add Qwen3 and Qwen3MoE (vllm-project#15289)

Signed-off-by: YamPengLi <yampayne.lyp@alibaba-inc.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [Misc] improve example mlpspeculator and llm_engine_example (vllm-project#16175)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Doc]Update image to latest version (vllm-project#16186)

Signed-off-by: WangErXiao <863579016@qq.com>

* Upstream Llama4 Support to Main (vllm-project#16113)

Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Re-enable support for `ChatGLMForConditionalGeneration` (vllm-project#16187)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [V1] Revert the default `max_num_seqs` to V0 values for most hardware (vllm-project#16158)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* Print the warning only once (vllm-project#16193)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [Misc] Human-readable `max-model-len` cli arg (vllm-project#16181)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

* [Misc] Move Llama 4 projector call into encoder execution (vllm-project#16201)

* [Bugfix] Fix guidance backend for Qwen models (vllm-project#16210)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>

* [V1][BugFix] Exit properly if engine core fails during startup (vllm-project#16137)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [Misc] add description attribute in CLI (vllm-project#15921)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Bugfix][V0] XGrammar structured output supports Enum (vllm-project#15878)

Signed-off-by: Leon Seidel <leon.seidel@fau.de>

* Torchao (vllm-project#14231)

Signed-off-by: drisspg <drisspguessous@gmail.com>

* [ROCm][Bugfix][FP8] Make fp8 quant respect fused modules mapping (vllm-project#16031)

Signed-off-by: mgoin <michael@neuralmagic.com>

* [core] do not send error across process (vllm-project#16174)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [Misc] Update compressed-tensors to version 0.9.3 (vllm-project#16196)

Signed-off-by: Miles Williams <42222518+mlsw@users.noreply.github.com>

* Update BASE_IMAGE to 2.22 release of Neuron (vllm-project#16218)

* [V1] Scatter and gather placeholders in the model runner (vllm-project#16076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>

* [Bugfix] fix use-ep bug to enable ep by dp/tp size > 1 (vllm-project#16161)

* Add warning for Attention backends that do not support irope yet (vllm-project#16212)

* [Bugfix] Do not skip "empty" parts of chats that are parsable (vllm-project#16219)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Bugfix] Fix and reorganize broken GGUF tests and bump gguf version (vllm-project#16194)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [torch.compile][TPU] Make @support_torch_compile work for XLA backend (vllm-project#15782)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

* [V1] Add `disable_chunked_mm_input` arg to disable partial mm input prefill (vllm-project#15837)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Misc] Merge the logs of pp layers partitions (vllm-project#16225)

Signed-off-by: Kebe <mail@kebe7jun.com>

* [Docs] Add Slides from Singapore Meetup (vllm-project#16213)

Signed-off-by: simon-mo <simon.mo@hey.com>

* [Misc] format and refactor some examples (vllm-project#16252)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Misc] Add warning for multimodal data in LLM.beam_search (vllm-project#16241)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* [Model] use AutoWeightsLoader for phimoe,qwen2_moe,qwen3_moe (vllm-project#16203)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [BugFix][ROCm] Fix GGUF MoE Dispatch Block_Dim for ROCm (vllm-project#16247)

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* [Bugfix] Remove triton do_bench fast_flush arg (vllm-project#16256)

Signed-off-by: Kebe <mail@kebe7jun.com>

* Update to transformers==4.51.1 (vllm-project#16257)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [New Model]: jinaai/jina-embeddings-v3 (vllm-project#16120)

* [Misc] Avoid stripping meaningful whitespace from `nvidia-smi topo -m` output in collect_env.py (vllm-project#16272)

Signed-off-by: imkero <kerorek@outlook.com>

* [Bugfix] Proper input validation for multi-modal encoder-decoder models (vllm-project#16156)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Handle `process_weights_after_loading` for `QKVCrossParallelLinear` (vllm-project#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>

* Add warning that content below line in template will be removed (vllm-project#16276)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [BugFix] Fix Llama4 - Index Error When Single Request Near Max Context (vllm-project#16209)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* [Bugfix] fix deepseek fp16 scale bug (vllm-project#14809)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

* [V1] Update structured output offline inference example (vllm-project#15721)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [CI/Build] Fix CI LoRA failure (vllm-project#16270)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* Add support to modelopt quantization of Mixtral model (vllm-project#15961)

Signed-off-by: Yue <yueshen@nvidia.com>

* [Model] Add smolvlm support (vllm-project#16017)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (vllm-project#16198)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>

* [Bugfix] fix gettid method is not define (vllm-project#16084)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Feature] Estimate max-model-len use available KV cache memory (vllm-project#16168)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Core] Upgrade to xgrammar 0.1.18, add cache size limit (vllm-project#16283)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [CI][Bugfix] Fix bad tolerance for test_batch_base64_embedding (vllm-project#16221)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [TPU] Update PyTorch/XLA (vllm-project#16288)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [BugFix] Fix fusion test and add them to CI (vllm-project#16287)

Signed-off-by: luka <luka@neuralmagic.com>

* [Misc] Fix test_sharded_state_loader.py(vllm-project#16004) (vllm-project#16005)

Signed-off-by: lvfei.lv <lvfei.lv@alibaba-inc.com>

* [Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (vllm-project#16273)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* Update label-tpu mergify and remove removal bot (vllm-project#16298)

* [BugFix] logger is not callable (vllm-project#16312)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [BugFix] llama4 qknorm should be not shared across head (vllm-project#16311)

Signed-off-by: Lu Fang <fanglu@fb.com>

* update neuron config (vllm-project#16289)

Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>

* [BugFix] fix some typos found by typos. (vllm-project#16314)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [Model] Add `SupportsMultiModal.get_language_model` interface (vllm-project#16007)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [Bugfix][Frontend] respect provided default guided decoding backend (vllm-project#15476)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* Revert "Update label-tpu mergify and remove removal bot" (vllm-project#16350)

* [Bugfix] Fix profiling.py (vllm-project#16202)

Signed-off-by: zh Wang <rekind133@outlook.com>

* [Bugfix] catch AssertionError in MistralTokenizer as ValueError (vllm-project#16344)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* [CI]Fix hpu docker and numpy version for CI (vllm-project#16355)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>

* Fix `benchmark_throughput.py --backend=hf` (vllm-project#16352)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Build/CI] Add tracing deps to vllm container image (vllm-project#15224)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [Hardware] add platform-specific request validation api (vllm-project#16291)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

* [Misc] refactor Structured Outputs example (vllm-project#16322)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (vllm-project#16275)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* Add GLM-4-0414 support (vllm-project#16338)

Signed-off-by: lvfei.lv <lvfei.lv@alibaba-inc.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Co-authored-by: Accelerator1996 <lvfei.lv@alibaba-inc.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: ajayvohra2005 <ajayvohr@amazon.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* [Bugfix]: do not shutdown server if `skip_special_use=False` for MistralTokenizer (vllm-project#14094)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* [Model] use AutoWeightsLoader for granite, granitemoe, granitemoeshared, grok1, mixtral (vllm-project#16325)

Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>

* [TPU] Fix dummy loading OOM (vllm-project#16372)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [bugfix] Avoid the time consumption caused by creating dummy videos. (vllm-project#16371)

* [CI][Bugfix] Pin triton version for CPU (vllm-project#16384)

Signed-off-by: Roger Wang <ywang@roblox.com>

* [misc] use tqdm.auto where appropriate (vllm-project#16290)

Signed-off-by: Benjamin Kitor <bkitor@gigaio.com>

* [Bugfix][TPU] Fix TPU validate_request (vllm-project#16369)

Signed-off-by: Michael Goin <mgoin64@gmail.com>

* fix sonnet dataset sample when prefix len is very small (vllm-project#16379)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Model] use AutoWeightsLoader for deepseek_v2, internlm2 (vllm-project#16383)

Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>

* [Misc] Update transformers version limits of multi-modal tests (vllm-project#16381)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Fix validation error for text-only Mllama 3.2 (vllm-project#16377)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (vllm-project#16038)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [doc] add download model tips (vllm-project#16389)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Update Numba to 0.61.2 (vllm-project#16376)

Signed-off-by: cyy <cyyever@outlook.com>

* [Model] Remove image mm limit for LLaMa4  (vllm-project#16365)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>

* [doc] update the wrong link (vllm-project#16401)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [CI] Add auto update workflow for Dockerfile graph (vllm-project#11879)

Signed-off-by: wineandchord <guoqizhou19@gmail.com>

* Fix the torch version parsing logic (vllm-project#15857)

* [VLM] Remove `BaseProcessingInfo.get_mm_max_tokens_per_item` (vllm-project#16408)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [TPU][V1] Use `language_model` interface for getting text backbone in MM (vllm-project#16410)

Signed-off-by: NickLucche <nlucches@redhat.com>

* Improve configs - `ParallelConfig` (vllm-project#16332)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [V1] Set structured output backend to `auto` by default (vllm-project#15724)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [V1][Spec Decode] Eagle Model loading (vllm-project#16035)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

* [Bugfix] Fix bug when dataset is json (vllm-project#15899)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (vllm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* [V1] Zero-copy tensor/ndarray serialization/transmission (vllm-project#13790)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [VLM] Avoid unnecessary dummy multimodal data during processing (vllm-project#16416)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Fix output token length check logic (vllm-project#16419)

Signed-off-by: look <eeslook@163.com>

* [TPU][V1] Disable per-request seed/Generator (vllm-project#16172)

Signed-off-by: NickLucche <nlucches@redhat.com>

* Fix range_ratio Bug in RandomDataset (vllm-project#16126)

Signed-off-by: jadewang21 <jadewangcn@outlook.com>

* check input length of sonnet samples (vllm-project#16423)

Signed-off-by: alexey-belyakov <alexey.belyakov@intel.com>

* update benchmark_serving_structured_output to include auto backend (vllm-project#16438)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Llama4] Enable attention temperature tuning by default for long context (>32k) (vllm-project#16439)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>

* Update supported_hardware.md for TPU INT8 (vllm-project#16437)

* [Bugfix][VLM] Fix failing Phi-4-MM multi-images tests and add vision-speech test (vllm-project#16424)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [CPU][Bugfix] Fix CPU docker issues (vllm-project#16454)

Signed-off-by: jiang.li <jiang1.li@intel.com>

* [Bugfix] Don't set an upper bound on repetition penalty (vllm-project#16403)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

* Revert "[Model] use AutoWeightsLoader for deepseek_v2, internlm2" (vllm-project#16453)

* [Core][LoRA][1/N] Add LoRA for EncoderDecoderModelRunner (vllm-project#15990)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* Enforce valid max_num_batched_tokens when disable_chunked_mm_input=True (vllm-project#16447)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Misc] Raise error for V1 not supporting Long LoRA. (vllm-project#16415)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* [Misc] update api_client example (vllm-project#16459)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Don't install triton on `ppc64le` platform (vllm-project#16470)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Kernel] support merge_attn_states CUDA kernel, 3x speedup (vllm-project#16173)

Signed-off-by: DefTruth <qiustudent_r@163.com>

* [Bugfix] Fix bugs of running Quark quantized models (vllm-project#16236)

Signed-off-by: chaow <chaow@amd.com>

* [Hardware][Intel-Gaudi] Multi-step scheduling implementation for HPU (vllm-project#12779)

Signed-off-by: Tomasz Zielinski <tomasz.zielinski@intel.com>

* Fix erroneous "model doesn't support compile" warning (vllm-project#16486)

Signed-off-by: rzou <zou3519@gmail.com>

* [TPU][V1] Make `--disable_chunked_mm_input` mandatory for serving MM models (vllm-project#16483)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [Kernel] Support W8A8 channel-wise weights and per-token activations in triton fused_moe_kernel (vllm-project#16366)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Doc] Document InternVL3 support (vllm-project#16495)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [Bugfix] handle alignment of encoder_seq_lens in mllama.py (vllm-project#14784)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* Improve configs - `LoadConfig` (vllm-project#16422)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Frontend] Added chat templates for LLaMa4 pythonic tool calling (vllm-project#16463)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Kai Wu <kaiwu@meta.com>

* [Kernel] Add tuned FusedMoE kernel config for Llama4 Scout, TP=8 on H100  (vllm-project#16488)

* Update openai_compatible_server.md (vllm-project#16507)

Signed-off-by: Christian Sears <csears@redhat.com>

* [Bugfix] clean up duplicated code (vllm-project#16485)

Signed-off-by: Gogs <gogs@fake.local>
Co-authored-by: Gogs <gogs@fake.local>

* Bugfix for PixtralHF models without spatial_merge_size (vllm-project#16513)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Doc] Fix link to vLLM blog (vllm-project#16519)

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>

* [CI][Bugfix] Add mistral_tool_use to Ci (vllm-project#16517)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [BugFix] Handle non-contiguous tensors properly when serializing (vllm-project#16492)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [Doc] Update Llama4 Model Names in Supported Models (vllm-project#16509)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>

* Optimized topk for topk=1 (Llama-4) (vllm-project#16512)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Feature][V1] Add xgrammar to support minLength, maxLength with test (vllm-project#16516)

Signed-off-by: Leon Seidel <leon.seidel@fau.de>

* [Frontend] support matryoshka representation / support embedding API dimensions (vllm-project#16331)

* fix: spelling (vllm-project#16466)

Signed-off-by: Tianer Zhou <ezhoureal@gmail.com>

* [Misc] Update chat utils tests (vllm-project#16520)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Misc] Openai transcription client example use same Whisper model (vllm-project#16487)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [V1] Enable multi-input by default (vllm-project#15799)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [MISC] Make GroupCoordinator compatible with out-of-tree devices (vllm-project#16464)

Signed-off-by: hzji210@gmail.com <hzji210@gmail.com>

* [Misc] Delete redundant code (vllm-project#16530)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* Fix syntaxWarning: invalid escape sequence '\s' (vllm-project#16532)

Signed-off-by: Jie Fu <jiefu@tencent.com>

* [Perf] Optimize Preparing Inputs for GPU Model Runner (vllm-project#16484)

Signed-off-by: snowcharm <snowcharmqq@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

* [Bugfix] Validate logit biases to prevent out of vocab ids crashing engine (vllm-project#16529)

Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>

* [V1][Spec Decode] KV cache slots for eagle heads (vllm-project#16370)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

* Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (vllm-project#16537)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Benchmark][Bugfix] Fix SonnetDataset default values in benchmark_throughput.py (vllm-project#16556)

* [Core][V0] Enable regex support with xgrammar (vllm-project#13228)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: chun37 <chun.jb.37@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: lukas.bluebaum <lukas.bluebaum@aleph-alpha.com>
Signed-off-by: Eric <erictang000@gmail.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Kay Yan <kay.yan@daocloud.io>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Ziji Shi <shi.ziji.sm@gmail.com>
Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
Signed-off-by: zhenwei <zhenweiliu@habana.ai>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: kevin <kevin@anyscale.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Tristan Leclercq <tristanleclercq@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Ben Jackson <ben@ben.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: paolovic <paul-philipp.luley@uzh.ch>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: YamPengLi <yampayne.lyp@alibaba-inc.com>
Signed-off-by: WangErXiao <863579016@qq.com>
Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Leon Seidel <leon.seidel@fau.de>
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Miles Williams <42222518+mlsw@users.noreply.github.com>
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Kebe <mail@kebe7jun.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Signed-off-by: imkero <kerorek@outlook.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Yue <yueshen@nvidia.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: lvfei.lv <lvfei.lv@alibaba-inc.com>
Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Signed-off-by: zh Wang <rekind133@outlook.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>
Signed-off-by: Benjamin Kitor <bkitor@gigaio.com>
Signed-off-by: Chenyaaang <chenyangli@google.com>
Signed-off-by: cyy <cyyever@outlook.com>
Signed-off-by: wineandchord <guoqizhou19@gmail.com>
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: look <eeslook@163.com>
Signed-off-by: jadewang21 <jadewangcn@outlook.com>
Signed-off-by: alexey-belyakov <alexey.belyakov@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
Signed-off-by: DefTruth <qiustudent_r@163.com>
Signed-off-by: chaow <chaow@amd.com>
Signed-off-by: Tomasz Zielinski <tomasz.zielinski@intel.com>
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Christian Sears <csears@redhat.com>
Signed-off-by: Gogs <gogs@fake.local>
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Signed-off-by: Tianer Zhou <ezhoureal@gmail.com>
Signed-off-by: hzji210@gmail.com <hzji210@gmail.com>
Signed-off-by: Jie Fu <jiefu@tencent.com>
Signed-off-by: snowcharm <snowcharmqq@gmail.com>
Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: chun <chun.jb.37@gmail.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Chris Thi <chris.c.thi@gmail.com>
Co-authored-by: LukasBluebaum <38468743+LukasBluebaum@users.noreply.github.com>
Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Kay Yan <kay.yan@daocloud.io>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Matthias Matt <37695050+meffmadd@users.noreply.github.com>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: rongfu.leng <lenronfu@gmail.com>
Co-authored-by: Nishidha <nishidha.panpaliya@partner.ibm.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Hyesoo Yang <45211235+hyeygit@users.noreply.github.com>
Co-authored-by: root <root@t1v-n-822696b7-w-0.us-central2-b.c.tpu-prod-env-large-adhoc.internal>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Ziji Shi (Steven) <shi.ziji.sm@gmail.com>
Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com>
Co-authored-by: yarongmu-google <150371854+yarongmu-google@users.noreply.github.com>
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Huy Do <huydhn@gmail.com>
Co-authored-by: Jonghyun Choe <andy.choe729@gmail.com>
Co-authored-by: liuzhenwei <zhenweiliu@habana.ai>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Ilya Markov <markovilya197@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Kevin H. Luu <kevin@anyscale.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tristan Leclercq <49700633+tristanleclercq@users.noreply.github.com>
Co-authored-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Ben Jackson <ben@ben.com>
Co-authored-by: Paul Schweigert <paul@paulschweigert.com>
Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: paolovic <91155454+paolovic@users.noreply.github.com>
Co-authored-by: paolovic <paul-philipp.luley@uzh.ch>
Co-authored-by: Martin Hoyer <mhoyer@redhat.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: YamPengLi <yampayne.lyp@alibaba-inc.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Robin <863579016@qq.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: leon-seidel <83984854+leon-seidel@users.noreply.github.com>
Co-authored-by: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Co-authored-by: Miles Williams <42222518+mlsw@users.noreply.github.com>
Co-authored-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
Co-authored-by: zxfan-cpu <zxfanzhang@tencent.com>
Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Kebe <mail@kebe7jun.com>
Co-authored-by: Alex Brooks <alex.brooks@ibm.com>
Co-authored-by: TY-AMD <tianyuan.wu@amd.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Kero Liang <kerorek@outlook.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: yueshen2016 <39203804+yueshen2016@users.noreply.github.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Co-authored-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Accelerator1996 <lvfei.lv@alibaba-inc.com>
Co-authored-by: ajayvohra2005 <ajayvohr@amazon.com>
Co-authored-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Co-authored-by: zh Wang <rekind133@outlook.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: Aaron Ang <67321817+aaron-ang@users.noreply.github.com>
Co-authored-by: Jintao <huangjintao@mail.ustc.edu.cn>
Co-authored-by: Benjamin Kitor <bkitor@gigaio.com>
Co-authored-by: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com>
Co-authored-by: cyyever <cyyever@outlook.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: wineandchord <guoqizhou123123@qq.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: look <eeslook@163.com>
Co-authored-by: WWW <jadewangcn@outlook.com>
Co-authored-by: Alexey Belyakov <alexey.belyakov@intel.com>
Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
Co-authored-by: chaow-amd <chaow@amd.com>
Co-authored-by: Tomasz Zielinski <85164140+tzielinski-habana@users.noreply.github.com>
Co-authored-by: Richard Zou <zou3519@users.noreply.github.com>
Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Kai Wu <kaiwu@meta.com>
Co-authored-by: Christian Sears <117944059+Chr1st1anSears@users.noreply.github.com>
Co-authored-by: Gogs <gogs@fake.local>
Co-authored-by: Yuan Tang <terrytangyuan@gmail.com>
Co-authored-by: Tianer Zhou <ezhoureal@gmail.com>
Co-authored-by: Huazhong Ji <hzji210@gmail.com>
Co-authored-by: Jie Fu (傅杰) <jiefu@tencent.com>
Co-authored-by: SnowCharm <qiuyilun@u.nus.edu>
Co-authored-by: Ryan McConville <ryan@ryanmcconville.com>
sfc-gh-mhidayetoglu added a commit to sfc-gh-mhidayetoglu/vllm that referenced this pull request May 1, 2025
* [V1] Fix: make sure `k_index` is int64 for `apply_top_k_only` (vllm-project#15907)

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* [Bugfix] Fix imports for MoE on CPU (vllm-project#15841)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>

* [V1][Minor] Enhance SpecDecoding Metrics Log in V1 (vllm-project#15902)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Doc] Update rocm.inc.md (vllm-project#15917)

Signed-off-by: chun37 <chun.jb.37@gmail.com>

* [V1][Bugfix] Fix typo in MoE TPU checking (vllm-project#15927)

Signed-off-by: Roger Wang <ywang@roblox.com>

* [Benchmark]Fix error message (vllm-project#15866)

Signed-off-by: wangli <wangli858794774@gmail.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>

* [Misc] Replace print with logger (vllm-project#15923)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [CI/Build] Further clean up LoRA tests (vllm-project#15920)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* [Bugfix] Fix cache block size calculation for CPU MLA (vllm-project#15848)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>

* [Build/CI] Update lm-eval to 0.4.8 (vllm-project#15912)

Signed-off-by: Chris Thi <chris.c.thi@gmail.com>

* [Kernel] Add more dtype support for GGUF dequantization (vllm-project#15879)

Signed-off-by: lukas.bluebaum <lukas.bluebaum@aleph-alpha.com>

* [core] Add tags parameter to wake_up() (vllm-project#15500)

Signed-off-by: Eric <erictang000@gmail.com>

* [V1] Fix json_object support with xgrammar (vllm-project#15488)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* Add minimum version for `huggingface_hub` to enable Xet downloads (vllm-project#15873)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Bugfix][Benchmarks] Ensure `async_request_deepspeed_mii` uses the OpenAI choices key (vllm-project#15926)

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>

* [CI] Remove duplicate entrypoints-test (vllm-project#15940)

Signed-off-by: Kay Yan <kay.yan@daocloud.io>

* [Bugfix] Fix the issue where the model name is empty string, causing no response with the model name. (vllm-project#15938)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [Metrics] Hide deprecated metrics (vllm-project#15458)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

* [Frontend] Implement Tool Calling with `tool_choice='required'` (vllm-project#13483)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>

* [CPU][Bugfix] Using custom allreduce for CPU backend (vllm-project#15934)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

* [Model] use AutoWeightsLoader in model load_weights (vllm-project#15770)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Misc] V1 LoRA support CPU offload (vllm-project#15843)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* Restricted cmake to be less than version 4 as 4.x breaks the build of… (vllm-project#15859)

Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com>

* [misc] instruct pytorch to use nvml-based cuda check (vllm-project#15951)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [V1] Support Mistral3 in V1 (vllm-project#15950)

Signed-off-by: mgoin <mgoin64@gmail.com>

* Fix `huggingface-cli[hf-xet]` -> `huggingface-cli[hf_xet]` (vllm-project#15969)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [V1][TPU] TPU-optimized top-p implementation (avoids scattering). (vllm-project#15736)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: root <root@t1v-n-822696b7-w-0.us-central2-b.c.tpu-prod-env-large-adhoc.internal>

* [TPU] optimize the all-reduce performance (vllm-project#15903)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [V1][TPU] Do not compile sampling more than needed (vllm-project#15883)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [ROCM][KERNEL] Paged attention for V1 (vllm-project#15720)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>

* fix: better error message for get_config close vllm-project#13889 (vllm-project#15943)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [bugfix] add seed in torchrun_example.py (vllm-project#15980)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [ROCM][V0] PA kennel selection when no sliding window provided (vllm-project#15982)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>

* [Benchmark] Add AIMO Dataset to Benchmark (vllm-project#15955)

Signed-off-by: Ziji Shi <shi.ziji.sm@gmail.com>
Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>

* [misc] improve error message for "Failed to infer device type" (vllm-project#15994)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [Bugfix][V1] Fix bug from putting llm_engine.model_executor in a background process (vllm-project#15367)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>

* [doc] update contribution link (vllm-project#15922)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* fix: tiny fix make format.sh excutable (vllm-project#16015)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [SupportsQuant] Bert, Blip, Blip2, Bloom (vllm-project#15573)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* [SupportsQuant] Chameleon, Chatglm, Commandr (vllm-project#15952)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

* [Neuron][kernel] Fuse kv cache into a single tensor (vllm-project#15911)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>

* [Minor] Fused experts refactor (vllm-project#15914)

Signed-off-by: Bill Nell <bnell@redhat.com>

* [Misc][Performance] Advance tpu.txt to the most recent nightly torch … (vllm-project#16024)

* Re-enable the AMD Testing for the passing tests. (vllm-project#15586)

Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>

* [TPU] Support sliding window and logit soft capping in the paged attention kernel for TPU. (vllm-project#15732)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>

* [TPU] Switch Test to Non-Sliding Window (vllm-project#15981)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>

* [Bugfix] Fix function names in test_block_fp8.py (vllm-project#16033)

Signed-off-by: Bill Nell <bnell@redhat.com>

* [ROCm] Tweak the benchmark script to run on ROCm (vllm-project#14252)

* [Misc] improve gguf check (vllm-project#15974)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [TPU][V1] Remove ragged attention kernel parameter hard coding (vllm-project#16041)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* doc: add info for macos clang errors (vllm-project#16049)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [V1][Spec Decode] Avoid logging useless nan metrics (vllm-project#16023)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>

* [Model] use AutoWeightsLoader for baichuan, gpt-neox, mpt (vllm-project#15939)

Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>

* [Hardware][Gaudi][BugFix] fix arguments of hpu fused moe (vllm-project#15945)

Signed-off-by: zhenwei <zhenweiliu@habana.ai>

* [Bugfix][kernels] Fix half2float conversion in gguf kernels (vllm-project#15995)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [Benchmark][Doc] Update throughput benchmark and README (vllm-project#15998)

Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

* [CPU] Change default block_size for CPU backend (vllm-project#16002)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

* [Distributed] [ROCM] Fix custom allreduce enable checks (vllm-project#16010)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>

* [ROCm][Bugfix] Use platform specific FP8 dtype (vllm-project#15717)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [ROCm][Bugfix] Bring back fallback to eager mode removed in vllm-project#14917, but for ROCm only (vllm-project#15413)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [Bugfix] Fix default behavior/fallback for pp in v1 (vllm-project#16057)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [CI] Reorganize .buildkite directory (vllm-project#16001)

Signed-off-by: kevin <kevin@anyscale.com>

* [V1] DP scale-out (1/N): Use zmq ROUTER/DEALER sockets for input queue (vllm-project#15906)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [V1] Scatter and gather placeholders in the model runner (vllm-project#15712)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>

* Revert "[V1] Scatter and gather placeholders in the model runner" (vllm-project#16075)

* [Kernel][Minor] Re-fuse triton moe weight application (vllm-project#16071)

Signed-off-by: Bill Nell <bnell@redhat.com>

* [Bugfix][TPU] Fix V1 TPU worker for sliding window (vllm-project#16059)

Signed-off-by: Michael Goin <mgoin64@gmail.com>

* [V1][Spec Decode] Update N-gram Proposer Interface (vllm-project#15750)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Misc] Auto detect bitsandbytes pre-quantized models (vllm-project#16027)

Signed-off-by: Tristan Leclercq <tristanleclercq@gmail.com>

* [CI] Fix benchmark script level (vllm-project#16089)

* fix: support clang17 for macos and fix the real libomp (vllm-project#16086)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [doc] fix 404 (vllm-project#16082)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Revert "doc: add info for macos clang errors (vllm-project#16049)" (vllm-project#16091)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* Fix some capitalisations in generated examples doc titles (vllm-project#16094)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Misc] format output for encoder_decoder.py (vllm-project#16095)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Misc] Remove redundant code (vllm-project#16098)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [Bugfix] fix use_atomic_add support of marlin kernel when using v1 engine (vllm-project#15946)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>

* [Model] use AutoWeightsLoader for phi, gemma, deepseek (vllm-project#16088)

Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>

* [Model] fix model testing for TeleChat2ForCausalLM and V0 llama4 (vllm-project#16112)

Signed-off-by: Lu Fang <fanglu@fb.com>

* [Benchmark] Add sampling parameters to benchmark_serving. (vllm-project#16022)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>

* [Frontend] Fix typo in tool chat templates for llama3.2 and toolace (vllm-project#14501)

Signed-off-by: Ben Jackson <ben@ben.com>

* [CI][V1] Fix passing `tokenizer` as kwarg to `validate_guidance_grammar` (vllm-project#16117)

Signed-off-by: Roger Wang <ywang@roblox.com>

* [Misc] refactor example eagle (vllm-project#16100)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Doc][Bugfix] Add missing EOF in k8s deploy doc (vllm-project#16025)

* [Misc] Improve model redirect to accept json dictionary (vllm-project#16119)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [Model] use AutoWeightsLoader for stablelm,starcoder2,zamba2 (vllm-project#16103)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Bugfix] LoRA : Fix the order in which the kernels process LoRAs  (vllm-project#16040)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>

* [Bugfix] add hf_token to EngineArgs (vllm-project#16093)

Signed-off-by: paolovic <paul-philipp.luley@uzh.ch>
Co-authored-by: paolovic <paul-philipp.luley@uzh.ch>

* [Misc] update requires-python in pyproject.toml (vllm-project#16116)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [TPU] Update PyTorch/XLA (vllm-project#16130)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [V1][Minor] Optimize get_cached_block (vllm-project#16135)

* Fix requires-python (vllm-project#16132)

* [Metrics] Add bucket for `request_latency`, `time_to_first_token` and `time_per_output_token` (vllm-project#15202)

Signed-off-by: Kay Yan <kay.yan@daocloud.io>

* [V1][Minor] Minor simplification for get_computed_blocks  (vllm-project#16139)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

* [Misc] Update Mistral-3.1 example (vllm-project#16147)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Make dummy encoder prompt padding alternative and add missing warnings (vllm-project#16129)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [CI] Set max transformers version for Ultravox model test  (vllm-project#16149)

Signed-off-by: Roger Wang <ywang@roblox.com>

* doc: fix some typos in doc (vllm-project#16154)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [VLM] Florence-2 supports online serving (vllm-project#16164)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [V1][Structured Output] Add `supports_structured_output()` method to Platform (vllm-project#16148)

Signed-off-by: shen-shanshan <467638484@qq.com>

* [Model] Add Qwen3 and Qwen3MoE (vllm-project#15289)

Signed-off-by: YamPengLi <yampayne.lyp@alibaba-inc.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [Misc] improve example mlpspeculator and llm_engine_example (vllm-project#16175)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Doc]Update image to latest version (vllm-project#16186)

Signed-off-by: WangErXiao <863579016@qq.com>

* Upstream Llama4 Support to Main (vllm-project#16113)

Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Re-enable support for `ChatGLMForConditionalGeneration` (vllm-project#16187)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [V1] Revert the default `max_num_seqs` to V0 values for most hardware (vllm-project#16158)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* Print the warning only once (vllm-project#16193)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>

* [Misc] Human-readable `max-model-len` cli arg (vllm-project#16181)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

* [Misc] Move Llama 4 projector call into encoder execution (vllm-project#16201)

* [Bugfix] Fix guidance backend for Qwen models (vllm-project#16210)

Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>

* [V1][BugFix] Exit properly if engine core fails during startup (vllm-project#16137)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [Misc] add description attribute in CLI (vllm-project#15921)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Bugfix][V0] XGrammar structured output supports Enum (vllm-project#15878)

Signed-off-by: Leon Seidel <leon.seidel@fau.de>

* Torchao (vllm-project#14231)

Signed-off-by: drisspg <drisspguessous@gmail.com>

* [ROCm][Bugfix][FP8] Make fp8 quant respect fused modules mapping (vllm-project#16031)

Signed-off-by: mgoin <michael@neuralmagic.com>

* [core] do not send error across process (vllm-project#16174)

Signed-off-by: youkaichao <youkaichao@gmail.com>

* [Misc] Update compressed-tensors to version 0.9.3 (vllm-project#16196)

Signed-off-by: Miles Williams <42222518+mlsw@users.noreply.github.com>

* Update BASE_IMAGE to 2.22 release of Neuron (vllm-project#16218)

* [V1] Scatter and gather placeholders in the model runner (vllm-project#16076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>

* [Bugfix] fix use-ep bug to enable ep by dp/tp size > 1 (vllm-project#16161)

* Add warning for Attention backends that do not support irope yet (vllm-project#16212)

* [Bugfix] Do not skip "empty" parts of chats that are parsable (vllm-project#16219)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Bugfix] Fix and reorganize broken GGUF tests and bump gguf version (vllm-project#16194)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [torch.compile][TPU] Make @support_torch_compile work for XLA backend (vllm-project#15782)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

* [V1] Add `disable_chunked_mm_input` arg to disable partial mm input prefill (vllm-project#15837)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Misc] Merge the logs of pp layers partitions (vllm-project#16225)

Signed-off-by: Kebe <mail@kebe7jun.com>

* [Docs] Add Slides from Singapore Meetup (vllm-project#16213)

Signed-off-by: simon-mo <simon.mo@hey.com>

* [Misc] format and refactor some examples (vllm-project#16252)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [Misc] Add warning for multimodal data in LLM.beam_search (vllm-project#16241)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

* [Model] use AutoWeightsLoader for phimoe,qwen2_moe,qwen3_moe (vllm-project#16203)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [BugFix][ROCm] Fix GGUF MoE Dispatch Block_Dim for ROCm (vllm-project#16247)

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* [Bugfix] Remove triton do_bench fast_flush arg (vllm-project#16256)

Signed-off-by: Kebe <mail@kebe7jun.com>

* Update to transformers==4.51.1 (vllm-project#16257)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [New Model]: jinaai/jina-embeddings-v3 (vllm-project#16120)

* [Misc] Avoid stripping meaningful whitespace from `nvidia-smi topo -m` output in collect_env.py (vllm-project#16272)

Signed-off-by: imkero <kerorek@outlook.com>

* [Bugfix] Proper input validation for multi-modal encoder-decoder models (vllm-project#16156)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Handle `process_weights_after_loading` for `QKVCrossParallelLinear` (vllm-project#15328)

Signed-off-by: Isotr0py <2037008807@qq.com>

* Add warning that content below line in template will be removed (vllm-project#16276)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [BugFix] Fix Llama4 - Index Error When Single Request Near Max Context (vllm-project#16209)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>

* [Bugfix] fix deepseek fp16 scale bug (vllm-project#14809)

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

* [V1] Update structured output offline inference example (vllm-project#15721)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [CI/Build] Fix CI LoRA failure (vllm-project#16270)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* Add support to modelopt quantization of Mixtral model (vllm-project#15961)

Signed-off-by: Yue <yueshen@nvidia.com>

* [Model] Add smolvlm support (vllm-project#16017)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* [Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs (vllm-project#16198)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>

* [Bugfix] fix gettid method is not define (vllm-project#16084)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Feature] Estimate max-model-len use available KV cache memory (vllm-project#16168)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>

* [Core] Upgrade to xgrammar 0.1.18, add cache size limit (vllm-project#16283)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [CI][Bugfix] Fix bad tolerance for test_batch_base64_embedding (vllm-project#16221)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [TPU] Update PyTorch/XLA (vllm-project#16288)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [BugFix] Fix fusion test and add them to CI (vllm-project#16287)

Signed-off-by: luka <luka@neuralmagic.com>

* [Misc] Fix test_sharded_state_loader.py(vllm-project#16004) (vllm-project#16005)

Signed-off-by: lvfei.lv <lvfei.lv@alibaba-inc.com>

* [Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (vllm-project#16273)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* Update label-tpu mergify and remove removal bot (vllm-project#16298)

* [BugFix] logger is not callable (vllm-project#16312)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [BugFix] llama4 qknorm should be not shared across head (vllm-project#16311)

Signed-off-by: Lu Fang <fanglu@fb.com>

* update neuron config (vllm-project#16289)

Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>

* [BugFix] fix some typos found by typos. (vllm-project#16314)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>

* [Model] Add `SupportsMultiModal.get_language_model` interface (vllm-project#16007)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [Bugfix][Frontend] respect provided default guided decoding backend (vllm-project#15476)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* Revert "Update label-tpu mergify and remove removal bot" (vllm-project#16350)

* [Bugfix] Fix profiling.py (vllm-project#16202)

Signed-off-by: zh Wang <rekind133@outlook.com>

* [Bugfix] catch AssertionError in MistralTokenizer as ValueError (vllm-project#16344)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* [CI]Fix hpu docker and numpy version for CI (vllm-project#16355)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>

* Fix `benchmark_throughput.py --backend=hf` (vllm-project#16352)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Build/CI] Add tracing deps to vllm container image (vllm-project#15224)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [Hardware] add platform-specific request validation api (vllm-project#16291)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

* [Misc] refactor Structured Outputs example (vllm-project#16322)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (vllm-project#16275)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* Add GLM-4-0414 support (vllm-project#16338)

Signed-off-by: lvfei.lv <lvfei.lv@alibaba-inc.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Co-authored-by: Accelerator1996 <lvfei.lv@alibaba-inc.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: ajayvohra2005 <ajayvohr@amazon.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* [Bugfix]: do not shutdown server if `skip_special_use=False` for MistralTokenizer (vllm-project#14094)

Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>

* [Model] use AutoWeightsLoader for granite, granitemoe, granitemoeshared, grok1, mixtral (vllm-project#16325)

Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>

* [TPU] Fix dummy loading OOM (vllm-project#16372)

Signed-off-by: Chengji Yao <chengjiyao@google.com>

* [bugfix] Avoid the time consumption caused by creating dummy videos. (vllm-project#16371)

* [CI][Bugfix] Pin triton version for CPU (vllm-project#16384)

Signed-off-by: Roger Wang <ywang@roblox.com>

* [misc] use tqdm.auto where appropriate (vllm-project#16290)

Signed-off-by: Benjamin Kitor <bkitor@gigaio.com>

* [Bugfix][TPU] Fix TPU validate_request (vllm-project#16369)

Signed-off-by: Michael Goin <mgoin64@gmail.com>

* fix sonnet dataset sample when prefix len is very small (vllm-project#16379)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Model] use AutoWeightsLoader for deepseek_v2, internlm2 (vllm-project#16383)

Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>

* [Misc] Update transformers version limits of multi-modal tests (vllm-project#16381)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Fix validation error for text-only Mllama 3.2 (vllm-project#16377)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Kernel] Use moe_wna16 kernel for compressed tensors wna16 moe models (vllm-project#16038)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [doc] add download model tips (vllm-project#16389)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Update Numba to 0.61.2 (vllm-project#16376)

Signed-off-by: cyy <cyyever@outlook.com>

* [Model] Remove image mm limit for LLaMa4  (vllm-project#16365)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>

* [doc] update the wrong link (vllm-project#16401)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* [CI] Add auto update workflow for Dockerfile graph (vllm-project#11879)

Signed-off-by: wineandchord <guoqizhou19@gmail.com>

* Fix the torch version parsing logic (vllm-project#15857)

* [VLM] Remove `BaseProcessingInfo.get_mm_max_tokens_per_item` (vllm-project#16408)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [TPU][V1] Use `language_model` interface for getting text backbone in MM (vllm-project#16410)

Signed-off-by: NickLucche <nlucches@redhat.com>

* Improve configs - `ParallelConfig` (vllm-project#16332)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [V1] Set structured output backend to `auto` by default (vllm-project#15724)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* [V1][Spec Decode] Eagle Model loading (vllm-project#16035)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

* [Bugfix] Fix bug when dataset is json (vllm-project#15899)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Model] Reduce redundant computations in mamba2 blocks for Bamba-9B (vllm-project#15423)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>

* [V1] Zero-copy tensor/ndarray serialization/transmission (vllm-project#13790)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [VLM] Avoid unnecessary dummy multimodal data during processing (vllm-project#16416)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Fix output token length check logic (vllm-project#16419)

Signed-off-by: look <eeslook@163.com>

* [TPU][V1] Disable per-request seed/Generator (vllm-project#16172)

Signed-off-by: NickLucche <nlucches@redhat.com>

* Fix range_ratio Bug in RandomDataset (vllm-project#16126)

Signed-off-by: jadewang21 <jadewangcn@outlook.com>

* check input length of sonnet samples (vllm-project#16423)

Signed-off-by: alexey-belyakov <alexey.belyakov@intel.com>

* update benchmark_serving_structured_output to include auto backend (vllm-project#16438)

Signed-off-by: Chenyaaang <chenyangli@google.com>

* [Llama4] Enable attention temperature tuning by default for long context (>32k) (vllm-project#16439)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>

* Update supported_hardware.md for TPU INT8 (vllm-project#16437)

* [Bugfix][VLM] Fix failing Phi-4-MM multi-images tests and add vision-speech test (vllm-project#16424)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [CPU][Bugfix] Fix CPU docker issues (vllm-project#16454)

Signed-off-by: jiang.li <jiang1.li@intel.com>

* [Bugfix] Don't set an upper bound on repetition penalty (vllm-project#16403)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

* Revert "[Model] use AutoWeightsLoader for deepseek_v2, internlm2" (vllm-project#16453)

* [Core][LoRA][1/N] Add LoRA for EncoderDecoderModelRunner (vllm-project#15990)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* Enforce valid max_num_batched_tokens when disable_chunked_mm_input=True (vllm-project#16447)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Misc] Raise error for V1 not supporting Long LoRA. (vllm-project#16415)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

* [Misc] update api_client example (vllm-project#16459)

Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

* Don't install triton on `ppc64le` platform (vllm-project#16470)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Kernel] support merge_attn_states CUDA kernel, 3x speedup (vllm-project#16173)

Signed-off-by: DefTruth <qiustudent_r@163.com>

* [Bugfix] Fix bugs of running Quark quantized models (vllm-project#16236)

Signed-off-by: chaow <chaow@amd.com>

* [Hardware][Intel-Gaudi] Multi-step scheduling implementation for HPU (vllm-project#12779)

Signed-off-by: Tomasz Zielinski <tomasz.zielinski@intel.com>

* Fix erroneous "model doesn't support compile" warning (vllm-project#16486)

Signed-off-by: rzou <zou3519@gmail.com>

* [TPU][V1] Make `--disable_chunked_mm_input` mandatory for serving MM models (vllm-project#16483)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [Kernel] Support W8A8 channel-wise weights and per-token activations in triton fused_moe_kernel (vllm-project#16366)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Doc] Document InternVL3 support (vllm-project#16495)

Signed-off-by: Isotr0py <2037008807@qq.com>

* [Bugfix] handle alignment of encoder_seq_lens in mllama.py (vllm-project#14784)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>

* Improve configs - `LoadConfig` (vllm-project#16422)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Frontend] Added chat templates for LLaMa4 pythonic tool calling (vllm-project#16463)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Kai Wu <kaiwu@meta.com>

* [Kernel] Add tuned FusedMoE kernel config for Llama4 Scout, TP=8 on H100  (vllm-project#16488)

* Update openai_compatible_server.md (vllm-project#16507)

Signed-off-by: Christian Sears <csears@redhat.com>

* [Bugfix] clean up duplicated code (vllm-project#16485)

Signed-off-by: Gogs <gogs@fake.local>
Co-authored-by: Gogs <gogs@fake.local>

* Bugfix for PixtralHF models without spatial_merge_size (vllm-project#16513)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Doc] Fix link to vLLM blog (vllm-project#16519)

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>

* [CI][Bugfix] Add mistral_tool_use to Ci (vllm-project#16517)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [BugFix] Handle non-contiguous tensors properly when serializing (vllm-project#16492)

Signed-off-by: Nick Hill <nhill@redhat.com>

* [Doc] Update Llama4 Model Names in Supported Models (vllm-project#16509)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>

* Optimized topk for topk=1 (Llama-4) (vllm-project#16512)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Feature][V1] Add xgrammar to support minLength, maxLength with test (vllm-project#16516)

Signed-off-by: Leon Seidel <leon.seidel@fau.de>

* [Frontend] support matryoshka representation / support embedding API dimensions (vllm-project#16331)

* fix: spelling (vllm-project#16466)

Signed-off-by: Tianer Zhou <ezhoureal@gmail.com>

* [Misc] Update chat utils tests (vllm-project#16520)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Misc] Openai transcription client example use same Whisper model (vllm-project#16487)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [V1] Enable multi-input by default (vllm-project#15799)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [MISC] Make GroupCoordinator compatible with out-of-tree devices (vllm-project#16464)

Signed-off-by: hzji210@gmail.com <hzji210@gmail.com>

* [Misc] Delete redundant code (vllm-project#16530)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* Fix syntaxWarning: invalid escape sequence '\s' (vllm-project#16532)

Signed-off-by: Jie Fu <jiefu@tencent.com>

* [Perf] Optimize Preparing Inputs for GPU Model Runner (vllm-project#16484)

Signed-off-by: snowcharm <snowcharmqq@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

* [Bugfix] Validate logit biases to prevent out of vocab ids crashing engine (vllm-project#16529)

Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>

* [V1][Spec Decode] KV cache slots for eagle heads (vllm-project#16370)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>

* Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (vllm-project#16537)

Signed-off-by: mgoin <mgoin64@gmail.com>

* [Benchmark][Bugfix] Fix SonnetDataset default values in benchmark_throughput.py (vllm-project#16556)

* [Core][V0] Enable regex support with xgrammar (vllm-project#13228)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

* capture only SP * batch_size <= max_batch_size case to cover small max_batch_size

---------

Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: chun37 <chun.jb.37@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Chris Thi <chris.c.thi@gmail.com>
Signed-off-by: lukas.bluebaum <lukas.bluebaum@aleph-alpha.com>
Signed-off-by: Eric <erictang000@gmail.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Kay Yan <kay.yan@daocloud.io>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Signed-off-by: Nishidha Panpaliya <nishidha.panpaliya@partner.ibm.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: root <root@banff-cyxtera-s65-4.amd.com>
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
Signed-off-by: Ziji Shi <shi.ziji.sm@gmail.com>
Signed-off-by: StevenShi-23 <shi.ziji.sm@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: reidliu41 <reid201711@gmail.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Jonghyun Choe <andy.choe729@gmail.com>
Signed-off-by: zhenwei <zhenweiliu@habana.ai>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: ilmarkov <imarkov@redhat.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: kevin <kevin@anyscale.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Tristan Leclercq <tristanleclercq@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Ben Jackson <ben@ben.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: paolovic <paul-philipp.luley@uzh.ch>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: YamPengLi <yampayne.lyp@alibaba-inc.com>
Signed-off-by: WangErXiao <863579016@qq.com>
Signed-off-by: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
Signed-off-by: drisspg <drisspguessous@gmail.com>
Signed-off-by: Jon Swenson <jmswen@gmail.com>
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Xiaodong Wang <xdwang@meta.com>
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
Signed-off-by: Lu Fang <lufang@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Leon Seidel <leon.seidel@fau.de>
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: Miles Williams <42222518+mlsw@users.noreply.github.com>
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: Kebe <mail@kebe7jun.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Signed-off-by: imkero <kerorek@outlook.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Yue <yueshen@nvidia.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: kliuae <kuanfu.liu@embeddedllm.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: lvfei.lv <lvfei.lv@alibaba-inc.com>
Signed-off-by: Ajay Vohra <ajayvohr@amazon.com>
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Signed-off-by: zh Wang <rekind133@outlook.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
Signed-off-by: Aaron Ang <aaron.angyd@gmail.com>
Signed-off-by: Benjamin Kitor <bkitor@gigaio.com>
Signed-off-by: Chenyaaang <chenyangli@google.com>
Signed-off-by: cyy <cyyever@outlook.com>
Signed-off-by: wineandchord <guoqizhou19@gmail.com>
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Signed-off-by: look <eeslook@163.com>
Signed-off-by: jadewang21 <jadewangcn@outlook.com>
Signed-off-by: alexey-belyakov <alexey.belyakov@intel.com>
Signed-off-by: jiang.li <jiang1.li@intel.com>
Signed-off-by: DefTruth <qiustudent_r@163.com>
Signed-off-by: chaow <chaow@amd.com>
Signed-off-by: Tomasz Zielinski <tomasz.zielinski@intel.com>
Signed-off-by: rzou <zou3519@gmail.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Christian Sears <csears@redhat.com>
Signed-off-by: Gogs <gogs@fake.local>
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Signed-off-by: Tianer Zhou <ezhoureal@gmail.com>
Signed-off-by: hzji210@gmail.com <hzji210@gmail.com>
Signed-off-by: Jie Fu <jiefu@tencent.com>
Signed-off-by: snowcharm <snowcharmqq@gmail.com>
Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Thien Tran <gau.nernst@yahoo.com.sg>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: chun <chun.jb.37@gmail.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Chris Thi <chris.c.thi@gmail.com>
Co-authored-by: LukasBluebaum <38468743+LukasBluebaum@users.noreply.github.com>
Co-authored-by: Eric Tang <46737979+erictang000@users.noreply.github.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Kay Yan <kay.yan@daocloud.io>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Matthias Matt <37695050+meffmadd@users.noreply.github.com>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Li, Jiang <jiang1.li@intel.com>
Co-authored-by: rongfu.leng <lenronfu@gmail.com>
Co-authored-by: Nishidha <nishidha.panpaliya@partner.ibm.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Hyesoo Yang <45211235+hyeygit@users.noreply.github.com>
Co-authored-by: root <root@t1v-n-822696b7-w-0.us-central2-b.c.tpu-prod-env-large-adhoc.internal>
Co-authored-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: root <root@banff-cyxtera-s65-4.amd.com>
Co-authored-by: yihong <zouzou0208@gmail.com>
Co-authored-by: Ziji Shi (Steven) <shi.ziji.sm@gmail.com>
Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com>
Co-authored-by: yarongmu-google <150371854+yarongmu-google@users.noreply.github.com>
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Huy Do <huydhn@gmail.com>
Co-authored-by: Jonghyun Choe <andy.choe729@gmail.com>
Co-authored-by: liuzhenwei <zhenweiliu@habana.ai>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Ilya Markov <markovilya197@gmail.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: Kevin H. Luu <kevin@anyscale.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tristan Leclercq <49700633+tristanleclercq@users.noreply.github.com>
Co-authored-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Ben Jackson <ben@ben.com>
Co-authored-by: Paul Schweigert <paul@paulschweigert.com>
Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: paolovic <91155454+paolovic@users.noreply.github.com>
Co-authored-by: paolovic <paul-philipp.luley@uzh.ch>
Co-authored-by: Martin Hoyer <mhoyer@redhat.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: YamPengLi <yampayne.lyp@alibaba-inc.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Robin <863579016@qq.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
Co-authored-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: leon-seidel <83984854+leon-seidel@users.noreply.github.com>
Co-authored-by: Driss Guessous <32754868+drisspg@users.noreply.github.com>
Co-authored-by: Miles Williams <42222518+mlsw@users.noreply.github.com>
Co-authored-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
Co-authored-by: zxfan-cpu <zxfanzhang@tencent.com>
Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Kebe <mail@kebe7jun.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Alex Brooks <alex.brooks@ibm.com>
Co-authored-by: TY-AMD <tianyuan.wu@amd.com>
Co-authored-by: wang.yuqi <noooop@126.com>
Co-authored-by: Kero Liang <kerorek@outlook.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: yueshen2016 <39203804+yueshen2016@users.noreply.github.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Co-authored-by: Hongxia Yang <hongxia.yang@amd.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Accelerator1996 <lvfei.lv@alibaba-inc.com>
Co-authored-by: ajayvohra2005 <ajayvohr@amazon.com>
Co-authored-by: Guillaume Calmettes <gcalmettes@scaleway.com>
Co-authored-by: zh Wang <rekind133@outlook.com>
Co-authored-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Yuxuan Zhang <2448370773@qq.com>
Co-authored-by: Aaron Ang <67321817+aaron-ang@users.noreply.github.com>
Co-authored-by: Jintao <huangjintao@mail.ustc.edu.cn>
Co-authored-by: Benjamin Kitor <bkitor@gigaio.com>
Co-authored-by: Chenyaaang <42742451+Chenyaaang@users.noreply.github.com>
Co-authored-by: cyyever <cyyever@outlook.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: wineandchord <guoqizhou123123@qq.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Co-authored-by: look <eeslook@163.com>
Co-authored-by: WWW <jadewangcn@outlook.com>
Co-authored-by: Alexey Belyakov <alexey.belyakov@intel.com>
Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
Co-authored-by: chaow-amd <chaow@amd.com>
Co-authored-by: Tomasz Zielinski <85164140+tzielinski-habana@users.noreply.github.com>
Co-authored-by: Richard Zou <zou3519@users.noreply.github.com>
Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Kai Wu <kaiwu@meta.com>
Co-authored-by: Christian Sears <117944059+Chr1st1anSears@users.noreply.github.com>
Co-authored-by: Gogs <gogs@fake.local>
Co-authored-by: Yuan Tang <terrytangyuan@gmail.com>
Co-authored-by: Tianer Zhou <ezhoureal@gmail.com>
Co-authored-by: Huazhong Ji <hzji210@gmail.com>
Co-authored-by: Jie Fu (傅杰) <jiefu@tencent.com>
Co-authored-by: SnowCharm <qiuyilun@u.nus.edu>
Co-authored-by: Ryan McConville <ryan@ryanmcconville.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants