
Conversation

Collaborator

@WoosukKwon commented Sep 19, 2025

Key Changes

  • Remove persistent batch
    • No “reordering” or complex bookkeeping
    • Almost all CPU state lives in NumPy arrays → we can vectorize most of the Python loops in pre-/post-processing
    • Simpler handling of requests resumed from preemption
  • GPU-persistent block tables
    • The CPU does not hold the block tables at all; the GPU maintains the persistent block tables.
    • In every step, we only send the “diff”s to the GPU and use a Triton kernel to update the persistent block tables (a sketch follows this list).
    • We also use another Triton kernel to create the new ephemeral block tables used for this forward pass.
    • More scalable as max_model_len and num_kv_groups increase
  • Use NumPy arrays (instead of Python lists) for ModelRunnerOutput & SchedulerOutput
    • Enables vectorization of input pre-processing
    • Lower serialization & GC overheads
  • Triton-native sampler (a second sketch follows this list)
    • No -1 temperature hack for greedy sampling
    • Efficient support for per-request seeds
    • Efficient support for logprobs by materializing only the top-k logprobs instead of the whole vocab
  • Data-parallel sampler (optional, when TP > 1)
    • Current: compute logits → all-gather → sampler
    • V2: compute logits → all-gather (ideally all-to-all) → sampler → all-gather
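
The diff-based block-table update can be pictured with a small scatter kernel. This is an illustrative sketch only: the kernel name, the (req_idx, slot, block_id) diff layout, and the wrapper are assumptions, not the PR's actual code.

import torch
import triton
import triton.language as tl

@triton.jit
def update_block_table_kernel(
    block_table_ptr,   # persistent [max_num_reqs, max_num_blocks] table on GPU
    req_idx_ptr,       # row (request index) of each diff entry
    slot_ptr,          # column (block slot) of each diff entry
    block_id_ptr,      # new physical block id to write
    table_stride,      # elements per row of the persistent table
    num_diffs,
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < num_diffs
    req = tl.load(req_idx_ptr + offs, mask=mask, other=0)
    slot = tl.load(slot_ptr + offs, mask=mask, other=0)
    blk = tl.load(block_id_ptr + offs, mask=mask, other=0)
    # Scatter each diff into its (row, slot) cell of the persistent table.
    tl.store(block_table_ptr + req * table_stride + slot, blk, mask=mask)

def apply_block_table_diffs(block_table: torch.Tensor, req_idx: torch.Tensor,
                            slot: torch.Tensor, block_id: torch.Tensor) -> None:
    # Only the diffs cross the CPU→GPU boundary; the table never leaves the GPU.
    num_diffs = req_idx.numel()
    if num_diffs == 0:
        return
    grid = (triton.cdiv(num_diffs, 256),)
    update_block_table_kernel[grid](
        block_table, req_idx, slot, block_id,
        block_table.stride(0), num_diffs, BLOCK=256)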
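
The sampler points can be illustrated with the Gumbel-max trick: argmax(logits / T + g) with g ~ Gumbel(0, 1) draws from softmax(logits / T), so greedy requests simply skip the noise instead of being encoded with a sentinel temperature. This is a PyTorch sketch of the math only; the PR's sampler fuses this in Triton with counter-based per-request RNG, and the function name here is hypothetical.

import torch

def sample(logits: torch.Tensor,       # [num_reqs, vocab_size]
           temperature: torch.Tensor,  # [num_reqs]; 0 means greedy
           seeds: torch.Tensor) -> torch.Tensor:
    greedy = logits.argmax(dim=-1)
    # Per-request seeded Gumbel noise (looped here for clarity; a real
    # implementation generates it in parallel inside the kernel).
    gumbel = torch.empty_like(logits)
    for i, seed in enumerate(seeds.tolist()):
        gen = torch.Generator(device=logits.device).manual_seed(seed)
        u = torch.rand(logits.shape[-1], generator=gen, device=logits.device)
        gumbel[i] = -torch.log(-torch.log(u))
    scaled = logits / temperature.clamp_min(1e-5).unsqueeze(-1)
    sampled = (scaled + gumbel).argmax(dim=-1)
    # No -1 temperature hack: greedy is an explicit mask, not a sentinel value.
    return torch.where(temperature == 0, greedy, sampled)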

TODOs

  • CUDA graphs
  • Efficient support for prompt logprobs

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
WoosukKwon and others added 6 commits September 20, 2025 11:43
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Contributor

@yinghai left a comment

Lgtm!

    # Request not found.
    return
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
Contributor

Would it be easier to make this a set so that we can check uniqueness?

layer_names = set()
for group in kv_cache_config.kv_cache_groups:
    for layer_name in group.layer_names:
        layer_names.add(layer_name)
assert layer_names == set(kv_cache_raw_tensors.keys())
Contributor

Maybe cache this
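
A minimal sketch of this caching suggestion, assuming kv_cache_config is fixed after initialization; _cached_layer_names is an invented attribute, not existing code.

# Hypothetical: compute the layer-name set once and reuse it across calls.
if getattr(self, "_cached_layer_names", None) is None:
    self._cached_layer_names = {
        layer_name
        for group in kv_cache_config.kv_cache_groups
        for layer_name in group.layer_names
    }
assert self._cached_layer_names == set(kv_cache_raw_tensors.keys())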

attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
    layer_names = kv_cache_group_spec.layer_names
    any_layer_name = next(iter(layer_names))
Collaborator

@LucasWilkinson commented Sep 24, 2025

This appears to assume an always-on hybrid KV-cache manager. I am 100% supportive of this, but IIRC the reason we still supported disabling the hybrid KV-cache manager was that P/D did not support it yet. cc @NickLucche @heheda12345: do you know the state of P/D + hybrid KV-cache? Is there any other reason we would want to disable the hybrid KV-cache manager?

Collaborator

Still not supported with P/D; @KuntaiDu has a series of PRs to enable the hybrid allocator + KV connectors first.

[scheduler_output.num_scheduled_tokens[i] for i in req_ids],
dtype=np.int32)

# TODO(woosuk): Support CUDA graphs.
Collaborator

Nice! Happy to see this above attention metadata prep; we should definitely treat full CUDA graphs as first-class, and I think padding for CUDA graphs before attention metadata prep has been a major source of headaches/bugs. Very excited about this!

WoosukKwon and others added 2 commits September 24, 2025 08:19
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
Member

@njhill left a comment

@WoosukKwon great work, I really like the direction.

Left a few comments on random things I noticed.

And of course there will be some work to integrate with other features as discussed.

def add_request(
    self,
    req_id: str,
    prompt_token_ids: list[int],
Member

Maybe we can also look at keeping things as np arrays end to end, or array.array for lists that grow (for later); a sketch of the idea follows.
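
Illustrative only, not code from the PR: array.array gives a compact, appendable integer buffer that NumPy can view without copying, so growing per-request token lists stay vectorizable.

from array import array
import numpy as np

token_ids = array("i", [1, 2, 3])  # typed, contiguous, cheap appends
token_ids.append(4)
# Zero-copy view for vectorized ops; note the array cannot be resized
# while this view is alive (its buffer is exported).
view = np.frombuffer(token_ids, dtype=np.int32)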

        )
        self.np = np.zeros_like(self.buffer.np)

    def copy_np_to_gpu(self, x: np.ndarray) -> torch.Tensor:
Member

Have this take the mapping?

# TODO(woosuk): Support CUDA graphs.
num_tokens_after_padding = num_tokens

idx_mapping_list = [
Member

Could we skip updating idx_mapping and recreating the sampling metadata if the req_ids list hasn't changed? (Sketch below.)
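
A minimal sketch of that fast path; the _prev_req_ids attribute and the _make_* helpers are invented for illustration.

# Rebuild idx_mapping and sampling metadata only when the request set changed.
if req_ids != self._prev_req_ids:
    self._prev_req_ids = list(req_ids)
    self.idx_mapping = self._make_idx_mapping(req_ids)              # hypothetical
    self.sampling_metadata = self._make_sampling_metadata(req_ids)  # hypothetical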

@josephrocca commented Oct 2, 2025

I'm not sure if this is expected behavior given that this PR is only a draft, but when testing this PR with -dcp 4 for DeepSeek R1/V3 on 4xH200, I get gibberish outputs and eventually a crash.

If I remove -dcp 4 then it works fine.

git clone --branch woosuk/model-runner-v2 https://github.com/vllm-project/vllm && cd vllm && git reset --hard 866eef50cae7f9a5f10dbbad8cdf34d07c943f1b
VLLM_USE_PRECOMPILED=1 uv pip install --editable .[flashinfer]
vllm serve RedHatAI/DeepSeek-R1-0528-quantized.w4a16 --tensor-parallel-size 4 -dcp 4 --async-scheduling --served-model-name default --max-model-len 9216

Then I send about 50 concurrent requests at it for a minute or so (with several requests sharing a prefix; unsure if that matters), and it crashes.

Click to see crash logs
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671] WorkerProc hit an exception.
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671] Traceback (most recent call last):
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/executor/multiproc_executor.py", line 666, in worker_busy_loop
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     output = func(*args, **kwargs)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/worker/gpu_worker.py", line 443, in execute_model
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     return self.model_runner.execute_model(scheduler_output)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/worker/gpu/model_runner.py", line 476, in execute_model
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     return self.postprocess(sampler_output, input_batch)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/worker/gpu/model_runner.py", line 424, in postprocess
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671] RuntimeError: shape mismatch: value tensor of shape [33, 4] cannot be broadcast to indexing result of shape [129, 1]
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671] Traceback (most recent call last):
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/executor/multiproc_executor.py", line 666, in worker_busy_loop
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     output = func(*args, **kwargs)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/worker/gpu_worker.py", line 443, in execute_model
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     return self.model_runner.execute_model(scheduler_output)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/worker/gpu/model_runner.py", line 476, in execute_model
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     return self.postprocess(sampler_output, input_batch)
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]   File "/root/vllm/vllm/v1/worker/gpu/model_runner.py", line 424, in postprocess
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     self.req_states.last_sampled_tokens[input_batch.idx_mapping] = (
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671]     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671] RuntimeError: shape mismatch: value tensor of shape [33, 4] cannot be broadcast to indexing result of shape [129, 1]
(Worker_TP0 pid=86929) ERROR 10-02 09:11:38 [multiproc_executor.py:671] 
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [dump_input.py:69] Dumping input data for V1 LLM engine (v0.10.2rc3.dev562+g866eef50c) with config: model='RedHatAI/DeepSeek-R1-0528-quantized.w4a16', speculative_config=None, tokenizer='RedHatAI/DeepSeek-R1-0528-quantized.w4a16', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=9216, download_dir=None, load_format=auto, tensor_parallel_size=4, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=default, enable_prefix_caching=True, chunked_prefill_enabled=True, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":null,"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"cudagraph_mode":0,"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"use_inductor_graph_partition":false,"pass_config":{},"max_capture_size":0,"local_cache_dir":null}, 
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [dump_input.py:76] Dumping scheduler output for model execution: SchedulerOutput(scheduled_new_reqs=[NewRequestData(req_id=cmpl-f5ac59d4eaec4aacb458a90c6fc55958-0,prompt_token_ids_len=2175,mm_features=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=0, min_p=0.008, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=True, ignore_eos=False, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None),block_ids=([1, 2, 3, 4, 5, 6, 7, 8, 440],),num_computed_tokens=2048,lora_request=None,prompt_embeds_shape=None), NewRequestData(req_id=cmpl-41c9be6bbeab494497ce127c8aa5b39e-0,prompt_token_ids_len=3026,mm_features=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=0, min_p=0.008, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=True, ignore_eos=False, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None),block_ids=([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 441],),num_computed_tokens=2816,lora_request=None,prompt_embeds_shape=None), NewRequestData(req_id=cmpl-27b0cca15bee4ed8a42c7b5a28f1aa76-0,prompt_token_ids_len=991,mm_features=[],sampling_params=SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=0, min_p=0.008, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=True, ignore_eos=False, max_tokens=1024, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, structured_outputs=None, extra_args=None),block_ids=([1, 2, 3, 442],),num_computed_tokens=768,lora_request=None,prompt_embeds_shape=None)], scheduled_cached_reqs=CachedRequestData(req_ids=['cmpl-f7450937f6fa4d22a37c66983ae4f8e8-0', 'cmpl-5635fcef12b74537ac52fc5393adf805-0', 'cmpl-a31e50b17a3e42d1a3b69d1dcb0233a0-0', 'cmpl-da0df0ad6e4e48a2bed1f127acb22f92-0', 'cmpl-ebd6bf35080745dc85bd2013d0a63787-0', 'cmpl-8a3c673f5e9e4a94adbaca2406adf82d-0', 'cmpl-ebc1bf8cbeb0452b9723895bccaecfe2-0', 'cmpl-c5bc805bc5974aafa7163bcf8fffc6f2-0', 'cmpl-f6677c5ca7d748a6885a7af8f4141e10-0', 'cmpl-2bff7ac9c92b49f0be85618a8262b651-0', 'cmpl-968674a79646462eb9027ca32e021787-0', 'cmpl-593350182dfa4625bf7820c599bfe8cf-0', 'cmpl-ffde80ce066b4f879d72061a26f05ab1-0', 'cmpl-65542aebb0c642b1aa35c2cfc77242de-0', 'cmpl-d28ac303b83b452381e3e38aa93260c4-0', 'cmpl-8d2445487dcf4c7ca704c2a358fbfe20-0', 'cmpl-549f6b8d1c5f4fb3a162d2fbfeebc0ae-0', 'cmpl-7c2f775af52c4f6eb941c82fdfde8250-0', 'cmpl-de7c0d8427d0489ebb496c8f56dcd8ce-0', 'cmpl-34733243eceb4bf98ea052128ac1638c-0', 'cmpl-3d56a08d06874f6f81e8e2d830583059-0', 'cmpl-d974e29724a9427a83b25957847a0659-0', 'cmpl-cc5d9f51236243c2be791a7ff1ffd36c-0', 'cmpl-e85ed326438e4c55885c68535fec2e49-0', 'cmpl-6d711b76f82f4c939ba9c1323ef1810a-0', 'cmpl-3991a6a2a09d4e389662b90a169f9a27-0', 'cmpl-0645cf1a26ed4c899e1771f7a1e445de-0', 'cmpl-2b5b980bdfbc411abe0f1f66726d2cb5-0', 'cmpl-6f7e796fae8e4ea4af2b35043d7473e0-0', 'cmpl-405520c2280041b0b042de7abc46899f-0', 
'cmpl-e4700caaf0ca4ea5bc5e071085dd1de5-0', 'cmpl-745685799ae64b27ac31882ef58cac9e-0', 'cmpl-fdbeca9dc4c94baba0e6433f9324336a-0', 'cmpl-e7ed729df65847718052c001f6958e47-0', 'cmpl-7eed2976fe094c5fb67517199f6c4267-0', 'cmpl-de25651eb1c347059feff1787602a45b-0', 'cmpl-41f2d48dbad84cfd9c627ce5c26ef4be-0', 'cmpl-c0b39e566c214e019c7793abedd477a8-0', 'cmpl-f5613b80dd5a4f09860a6e9cbe2b4aa7-0', 'cmpl-6e78256b7f2b44eeae07f6b688f293a8-0', 'cmpl-24dacd4042aa46bf92dc8b1f56deec7c-0', 'cmpl-510183b7830d4d248d23facc5ceed535-0', 'cmpl-276c6326d9294d01954d2dffaa6471a3-0', 'cmpl-5feb9d732d974fc08fbb78a6e80ec7e4-0', 'cmpl-7ec90280c9d04877bba9ac885fe55503-0', 'cmpl-223e1bd125c942d09fedcafb4ed97d6d-0', 'cmpl-0b22d330e7b442ac89d776d0a81e1795-0', 'cmpl-d8f1b6c912014afd99f82b6fe5fe9e1a-0', 'cmpl-cde514fc6fa644afade67568503504b4-0', 'cmpl-79db484ed6454c8fb4832a2fd9222311-0', 'cmpl-46ba1ffcacd642108d138f0f8c315857-0', 'cmpl-6b917d692b9342238c37ad6df2b43a85-0', 'cmpl-64839dfdc3a748059c6b4c4815fc313d-0', 'cmpl-5119073cdafc4afa8e523a675e8670e9-0', 'cmpl-2c2ef17c45e245cdb293dc1f7baeeac9-0', 'cmpl-52865351143948ce9bfc0a7e3b5f4bd7-0', 'cmpl-0f7e0bba18b8485b8d8fbb8d9f4ec7b2-0', 'cmpl-0beab308c6304a8f86c9dca844f8aefc-0', 'cmpl-96dfe4dc9f9b454b84176c5e63085bd4-0', 'cmpl-9a1f25461cc74d79a21536f6bdfd5532-0', 'cmpl-2699220217f74e238e9ad6ba7aebee3b-0', 'cmpl-2be5a549e6c449c299e3fdffa1d7bf62-0', 'cmpl-98c58b22f53b47af80443ede7385a479-0', 'cmpl-202825959e2e40dea8d2ed14e3917332-0', 'cmpl-ed2a7b3138b14decb65636c7818e04fd-0', 'cmpl-ebf19114caf94e9fa6aa207300f7ea7d-0', 'cmpl-76930fa2b4cb4175a06c7b537b6d776f-0', 'cmpl-cc5af4ec7fc8480c91d7a06d7584247e-0', 'cmpl-b1a9c6f7c190432a91166682f5f3022a-0', 'cmpl-3751f2f408da49bda23a2c7d0aca0278-0', 'cmpl-7e38164df20d4c47b4e48158721dd9f0-0', 'cmpl-0ab07f1c24744c1d8bb17ff14fcb2394-0', 'cmpl-399051b2633a4ca5941a209e1206f6e1-0', 'cmpl-5e3bfbd0e4fb4894b02db20fd02c52da-0', 'cmpl-e18db0a44bd84ac4ab24b014deb6bf15-0', 'cmpl-784d99073cca4c49b24aa613d70cb908-0', 'cmpl-936bcdd49b0248b88ec89de0aba49424-0', 'cmpl-7d3becac09c943c283bfd072602dbc11-0', 'cmpl-e1887d7754164c33a5af0dd2293ffe3c-0', 'cmpl-32edd36530344c3a978cee2d899e15d1-0', 'cmpl-dd38afd5d75745e2888c6d4152324b5b-0', 'cmpl-712a589fa4da4bc0be6fac6edbed495a-0', 'cmpl-b583a281c3f1498fb1f1ee50a57591da-0', 'cmpl-724f58b721664bf69e3d5671ca497f6b-0', 'cmpl-151ed815662d4c8a9b01cdb5fe42fdc8-0', 'cmpl-168a0c41f3274625bb2856ed35a65c1e-0', 'cmpl-a98169c29e054617af65ec7c5e26f420-0', 'cmpl-90a6f78bbb1b49fdafaa19ee20ffd47e-0', 'cmpl-037a29590615415fbfd6d02ff3d69503-0', 'cmpl-892fdc93d53b4962a794f126536b1ed6-0', 'cmpl-64924e4680b64753afd3aa48e2b5bc19-0', 'cmpl-24c639dfd232461489653b8037271074-0', 'cmpl-ea031d7286a6476aa4b0317246193312-0', 'cmpl-965fde156df546359b4d808ffa98f2fe-0', 'cmpl-f31d5f5eb7eb44bd871627f1ec8b32da-0', 'cmpl-a6fcce63f66f478e866830863703fbd3-0', 'cmpl-8f252069eba44086a5622529c361e05f-0', 'cmpl-a4d074a8c94b410ebab57a40d4411a88-0', 'cmpl-cba6d4f84c3b4ff9bca09b36a9a87151-0', 'cmpl-48fb9e4ea3f34c4d949fcae8738f3761-0', 'cmpl-75db1684de404d4bafea58141fde0df6-0', 'cmpl-d04d7e2e4f554aa09c5502487e00ae3d-0', 'cmpl-5ae4197173994018ba35c07fbe0c1aa4-0', 'cmpl-da3ad1b653ca4c1c8768568d6d78f705-0', 'cmpl-b5e7080a1c734aafbb3887c8c98d3632-0', 'cmpl-7251a7a34e6e4a7aac7559a4bf8bb8c1-0', 'cmpl-b1954962a3b5462fa335c6948972e996-0', 'cmpl-ae62b7d6aa6c4a40ae545352bfc5660a-0', 'cmpl-0dc9fd5a9edc49ce97a209f8fe225483-0', 'cmpl-0e0e8242187a4eefafe62e87f2f084ec-0', 'cmpl-a5ab6237f4e1471db2f772e7bb459c07-0', 'cmpl-5de5fdabd46f46f59fa7586491af43f0-0', 
'cmpl-df8e46a726694df9b2cc796e44de2183-0', 'cmpl-83214304d5a14046861f07086c350bf1-0', 'cmpl-882e98815829419695b5cf159b343730-0', 'cmpl-bd67b762ab9f4be9a4cd799dfdfb5054-0', 'cmpl-597a4536254f4473982dc651713fd437-0', 'cmpl-34ad840bed994cbe87aee7ed3c3234fd-0', 'cmpl-035eb1ecfdb34bb1a156aed91625babc-0', 'cmpl-aa659be37dd14640a80e613052b6a15c-0', 'cmpl-3d5f56c46211499c98d9179b30d7dc9a-0', 'cmpl-3456b9e97ddf4733946cbca2909ba789-0', 'cmpl-fbb7074cd76e495f99305281c9bcde36-0', 'cmpl-61987320a7854b018c67ea74db4a74bd-0', 'cmpl-5d0aa92b981d49da96172425731d658e-0', 'cmpl-35872183b9eb42bab9e36478b003bc3c-0'], resumed_from_preemption=[false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false, false], new_token_ids=[], new_block_ids=[null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null], num_computed_tokens=[1180, 1722, 3108, 2420, 1372, 1590, 2319, 1322, 1870, 1231, 1689, 1589, 1589, 1445, 1627, 2318, 1150, 2419, 1687, 3268, 1867, 2405, 2034, 1299, 1441, 3103, 4041, 1175, 1163, 6886, 1157, 1404, 1331, 3090, 2168, 3083, 1296, 1126, 1122, 1378, 3236, 1313, 1529, 1191, 2276, 1524, 1217, 1674, 1635, 1357, 1348, 1561, 1620, 1235, 1085, 1377, 1276, 1292, 1224, 3209, 1093, 1261, 1078, 1596, 3174, 2315, 2213, 1265, 1047, 1191, 1609, 1919, 1916, 1180, 1192, 1380, 2277, 2158, 1588, 1188, 2900, 1011, 1202, 1205, 1177, 1539, 1361, 1440, 1569, 997, 1419, 1316, 2022, 1048, 990, 1695, 2230, 1381, 1630, 1025, 1025, 1495, 1977, 1138, 1233, 1297, 940, 939, 1314, 1101, 1599, 1149, 1465, 1592, 3259, 2188, 940, 1044, 3954, 1988, 1389, 3795, 3802, 1626, 2175, 1664]), num_scheduled_tokens={cmpl-9a1f25461cc74d79a21536f6bdfd5532-0: 1, cmpl-90a6f78bbb1b49fdafaa19ee20ffd47e-0: 1, cmpl-202825959e2e40dea8d2ed14e3917332-0: 1, cmpl-510183b7830d4d248d23facc5ceed535-0: 1, cmpl-724f58b721664bf69e3d5671ca497f6b-0: 1, cmpl-e7ed729df65847718052c001f6958e47-0: 1, cmpl-24dacd4042aa46bf92dc8b1f56deec7c-0: 1, cmpl-5feb9d732d974fc08fbb78a6e80ec7e4-0: 1, cmpl-cde514fc6fa644afade67568503504b4-0: 1, cmpl-dd38afd5d75745e2888c6d4152324b5b-0: 1, 
cmpl-ebf19114caf94e9fa6aa207300f7ea7d-0: 1, cmpl-7d3becac09c943c283bfd072602dbc11-0: 1, cmpl-f31d5f5eb7eb44bd871627f1ec8b32da-0: 1, cmpl-037a29590615415fbfd6d02ff3d69503-0: 1, cmpl-0dc9fd5a9edc49ce97a209f8fe225483-0: 1, cmpl-da3ad1b653ca4c1c8768568d6d78f705-0: 1, cmpl-cc5af4ec7fc8480c91d7a06d7584247e-0: 1, cmpl-5635fcef12b74537ac52fc5393adf805-0: 1, cmpl-5e3bfbd0e4fb4894b02db20fd02c52da-0: 1, cmpl-24c639dfd232461489653b8037271074-0: 1, cmpl-da0df0ad6e4e48a2bed1f127acb22f92-0: 1, cmpl-276c6326d9294d01954d2dffaa6471a3-0: 1, cmpl-d04d7e2e4f554aa09c5502487e00ae3d-0: 1, cmpl-5de5fdabd46f46f59fa7586491af43f0-0: 1, cmpl-48fb9e4ea3f34c4d949fcae8738f3761-0: 1, cmpl-405520c2280041b0b042de7abc46899f-0: 1, cmpl-f5613b80dd5a4f09860a6e9cbe2b4aa7-0: 1, cmpl-79db484ed6454c8fb4832a2fd9222311-0: 1, cmpl-6b917d692b9342238c37ad6df2b43a85-0: 1, cmpl-593350182dfa4625bf7820c599bfe8cf-0: 1, cmpl-e1887d7754164c33a5af0dd2293ffe3c-0: 1, cmpl-76930fa2b4cb4175a06c7b537b6d776f-0: 1, cmpl-bd67b762ab9f4be9a4cd799dfdfb5054-0: 1, cmpl-96dfe4dc9f9b454b84176c5e63085bd4-0: 1, cmpl-3751f2f408da49bda23a2c7d0aca0278-0: 1, cmpl-61987320a7854b018c67ea74db4a74bd-0: 1, cmpl-0f7e0bba18b8485b8d8fbb8d9f4ec7b2-0: 1, cmpl-5d0aa92b981d49da96172425731d658e-0: 1, cmpl-41f2d48dbad84cfd9c627ce5c26ef4be-0: 1, cmpl-46ba1ffcacd642108d138f0f8c315857-0: 1, cmpl-b583a281c3f1498fb1f1ee50a57591da-0: 1, cmpl-892fdc93d53b4962a794f126536b1ed6-0: 1, cmpl-fdbeca9dc4c94baba0e6433f9324336a-0: 1, cmpl-2c2ef17c45e245cdb293dc1f7baeeac9-0: 1, cmpl-b1a9c6f7c190432a91166682f5f3022a-0: 1, cmpl-aa659be37dd14640a80e613052b6a15c-0: 1, cmpl-df8e46a726694df9b2cc796e44de2183-0: 1, cmpl-8d2445487dcf4c7ca704c2a358fbfe20-0: 1, cmpl-d974e29724a9427a83b25957847a0659-0: 1, cmpl-5119073cdafc4afa8e523a675e8670e9-0: 1, cmpl-0ab07f1c24744c1d8bb17ff14fcb2394-0: 1, cmpl-34ad840bed994cbe87aee7ed3c3234fd-0: 1, cmpl-d28ac303b83b452381e3e38aa93260c4-0: 1, cmpl-a6fcce63f66f478e866830863703fbd3-0: 1, cmpl-882e98815829419695b5cf159b343730-0: 1, cmpl-2bff7ac9c92b49f0be85618a8262b651-0: 1, cmpl-b5e7080a1c734aafbb3887c8c98d3632-0: 1, cmpl-35872183b9eb42bab9e36478b003bc3c-0: 1, cmpl-98c58b22f53b47af80443ede7385a479-0: 1, cmpl-8a3c673f5e9e4a94adbaca2406adf82d-0: 1, cmpl-968674a79646462eb9027ca32e021787-0: 1, cmpl-0645cf1a26ed4c899e1771f7a1e445de-0: 1, cmpl-e4700caaf0ca4ea5bc5e071085dd1de5-0: 1, cmpl-7ec90280c9d04877bba9ac885fe55503-0: 1, cmpl-712a589fa4da4bc0be6fac6edbed495a-0: 1, cmpl-784d99073cca4c49b24aa613d70cb908-0: 1, cmpl-75db1684de404d4bafea58141fde0df6-0: 1, cmpl-41c9be6bbeab494497ce127c8aa5b39e-0: 210, cmpl-c0b39e566c214e019c7793abedd477a8-0: 1, cmpl-3d5f56c46211499c98d9179b30d7dc9a-0: 1, cmpl-223e1bd125c942d09fedcafb4ed97d6d-0: 1, cmpl-ae62b7d6aa6c4a40ae545352bfc5660a-0: 1, cmpl-83214304d5a14046861f07086c350bf1-0: 1, cmpl-6e78256b7f2b44eeae07f6b688f293a8-0: 1, cmpl-7251a7a34e6e4a7aac7559a4bf8bb8c1-0: 1, cmpl-3d56a08d06874f6f81e8e2d830583059-0: 1, cmpl-ebc1bf8cbeb0452b9723895bccaecfe2-0: 1, cmpl-2be5a549e6c449c299e3fdffa1d7bf62-0: 1, cmpl-168a0c41f3274625bb2856ed35a65c1e-0: 1, cmpl-a5ab6237f4e1471db2f772e7bb459c07-0: 1, cmpl-de7c0d8427d0489ebb496c8f56dcd8ce-0: 1, cmpl-a4d074a8c94b410ebab57a40d4411a88-0: 1, cmpl-399051b2633a4ca5941a209e1206f6e1-0: 1, cmpl-34733243eceb4bf98ea052128ac1638c-0: 1, cmpl-8f252069eba44086a5622529c361e05f-0: 1, cmpl-fbb7074cd76e495f99305281c9bcde36-0: 1, cmpl-7c2f775af52c4f6eb941c82fdfde8250-0: 1, cmpl-745685799ae64b27ac31882ef58cac9e-0: 1, cmpl-7eed2976fe094c5fb67517199f6c4267-0: 1, cmpl-52865351143948ce9bfc0a7e3b5f4bd7-0: 1, 
cmpl-6f7e796fae8e4ea4af2b35043d7473e0-0: 1, cmpl-035eb1ecfdb34bb1a156aed91625babc-0: 1, cmpl-64839dfdc3a748059c6b4c4815fc313d-0: 1, cmpl-a98169c29e054617af65ec7c5e26f420-0: 1, cmpl-ea031d7286a6476aa4b0317246193312-0: 1, cmpl-ebd6bf35080745dc85bd2013d0a63787-0: 1, cmpl-5ae4197173994018ba35c07fbe0c1aa4-0: 1, cmpl-d8f1b6c912014afd99f82b6fe5fe9e1a-0: 1, cmpl-c5bc805bc5974aafa7163bcf8fffc6f2-0: 1, cmpl-7e38164df20d4c47b4e48158721dd9f0-0: 1, cmpl-0e0e8242187a4eefafe62e87f2f084ec-0: 1, cmpl-cba6d4f84c3b4ff9bca09b36a9a87151-0: 1, cmpl-f7450937f6fa4d22a37c66983ae4f8e8-0: 1, cmpl-3456b9e97ddf4733946cbca2909ba789-0: 1, cmpl-597a4536254f4473982dc651713fd437-0: 1, cmpl-b1954962a3b5462fa335c6948972e996-0: 1, cmpl-2b5b980bdfbc411abe0f1f66726d2cb5-0: 1, cmpl-32edd36530344c3a978cee2d899e15d1-0: 1, cmpl-a31e50b17a3e42d1a3b69d1dcb0233a0-0: 1, cmpl-6d711b76f82f4c939ba9c1323ef1810a-0: 1, cmpl-f5ac59d4eaec4aacb458a90c6fc55958-0: 127, cmpl-e18db0a44bd84ac4ab24b014deb6bf15-0: 1, cmpl-965fde156df546359b4d808ffa98f2fe-0: 1, cmpl-936bcdd49b0248b88ec89de0aba49424-0: 1, cmpl-e85ed326438e4c55885c68535fec2e49-0: 1, cmpl-151ed815662d4c8a9b01cdb5fe42fdc8-0: 1, cmpl-27b0cca15bee4ed8a42c7b5a28f1aa76-0: 223, cmpl-0b22d330e7b442ac89d776d0a81e1795-0: 1, cmpl-de25651eb1c347059feff1787602a45b-0: 1, cmpl-0beab308c6304a8f86c9dca844f8aefc-0: 1, cmpl-2699220217f74e238e9ad6ba7aebee3b-0: 1, cmpl-f6677c5ca7d748a6885a7af8f4141e10-0: 1, cmpl-64924e4680b64753afd3aa48e2b5bc19-0: 1, cmpl-cc5d9f51236243c2be791a7ff1ffd36c-0: 1, cmpl-549f6b8d1c5f4fb3a162d2fbfeebc0ae-0: 1, cmpl-65542aebb0c642b1aa35c2cfc77242de-0: 1, cmpl-3991a6a2a09d4e389662b90a169f9a27-0: 1, cmpl-ffde80ce066b4f879d72061a26f05ab1-0: 1, cmpl-ed2a7b3138b14decb65636c7818e04fd-0: 1}, total_num_scheduled_tokens=686, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, num_common_prefix_blocks=[3], finished_req_ids=[], free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=null, kv_connector_metadata=null)
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [dump_input.py:79] Dumping scheduler stats: SchedulerStats(num_running_reqs=127, num_waiting_reqs=0, step_counter=0, current_wave=0, kv_cache_usage=0.02975420439844756, prefix_cache_stats=PrefixCacheStats(reset=False, requests=0, queries=0, hits=0), spec_decoding_stats=None, kv_connector_stats=None, num_corrupted_reqs=0)
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710] EngineCore encountered a fatal error.
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710] Traceback (most recent call last):
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 701, in run_engine_core
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     engine_core.run_busy_loop()
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 728, in run_busy_loop
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     self._process_engine_step()
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 754, in _process_engine_step
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     outputs, model_executed = self.step_fn()
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]                               ^^^^^^^^^^^^^^
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 346, in step_with_batch_queue
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     model_output = self.execute_model_with_error_logging(
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 270, in execute_model_with_error_logging
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     raise err
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 261, in execute_model_with_error_logging
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     return model_fn(scheduler_output)
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]            ^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/engine/core.py", line 347, in <lambda>
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     lambda _: future.result(), scheduler_output)
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]               ^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     return self.__get_result()
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]            ^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     raise self._exception
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     result = self.fn(*self.args, **self.kwargs)
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]   File "/root/vllm/vllm/v1/executor/multiproc_executor.py", line 248, in get_response
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710]     raise RuntimeError(
(EngineCore_DP0 pid=86780) ERROR 10-02 09:11:38 [core.py:710] RuntimeError: Worker failed with error 'shape mismatch: value tensor of shape [33, 4] cannot be broadcast to indexing result of shape [129, 1]', please check the stack trace above for the root cause
(Worker_TP0 pid=86929) INFO 10-02 09:11:38 [multiproc_executor.py:558] Parent process exited, terminating worker
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480] AsyncLLM output_handler failed.
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480] Traceback (most recent call last):
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480]   File "/root/vllm/vllm/v1/engine/async_llm.py", line 439, in output_handler
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480]     outputs = await engine_core.get_output_async()
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480]   File "/root/vllm/vllm/v1/engine/core_client.py", line 846, in get_output_async
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480]     raise self._format_exception(outputs) from None
(APIServer pid=86613) ERROR 10-02 09:11:38 [async_llm.py:480] vllm.v1.engine.exceptions.EngineDeadError: EngineCore encountered an issue. See stack trace (above) for the root cause.
(Worker_TP3 pid=86932) INFO 10-02 09:11:38 [multiproc_executor.py:558] Parent process exited, terminating worker
(Worker_TP1 pid=86930) INFO 10-02 09:11:38 [multiproc_executor.py:558] Parent process exited, terminating worker
(Worker_TP2 pid=86931) INFO 10-02 09:11:38 [multiproc_executor.py:558] Parent process exited, terminating worker

Notes:

  • For some reason I had to put this line within a with torch.no_grad():, otherwise I got an error like RuntimeError: sum(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.
  • I also wrapped this line with a try/except that falls back to kv_cache_stride_order = tuple(range(len(kv_cache_shape))). This was a guess based on what I saw in vllm/v1/worker/gpu_model_runner.py - apologies if this guess was wrong and is the cause of this issue.
  • It's possible that the underlying cause of this issue is the DCP implementation itself, since I'm observing some strange behavior when testing DCP in 0.10.2 (and on the latest main branch, and at ac201a0, when DCP support was merged). Specifically, prefix retrieval and/or storage is seemingly getting mixed up between requests under high concurrency (but there is no crash), resulting in outputs that seem 'coherent' but are clearly for the wrong request. That problem goes away when adding --no-enable-prefix-caching or removing -dcp 4. CC @youzhedian

@ProExpertProg
Collaborator

  • For some reason I had to put this line within a with torch.no_grad():, otherwise I got an error like RuntimeError: sum(): functions with out=... arguments don't support automatic differentiation, but one of the arguments requires grad.

Is it possible you're missing a @torch.inference_mode()?
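
For reference, a sketch of where such a decorator would sit; the placement and method name are assumptions, not the PR's actual code.

import torch

@torch.inference_mode()  # disables autograd tracking, so out=-based reductions no longer complain
def execute_model(self, scheduler_output):
    ...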

@WoosukKwon force-pushed the woosuk/model-runner-v2 branch from cf0666e to 09e4b2f on October 30, 2025 at 23:30
mergify bot added the ci/build label on Oct 30, 2025
mergify bot commented Oct 30, 2025

Documentation preview: https://vllm--25266.org.readthedocs.build/en/25266/

mergify bot added the documentation label on Oct 30, 2025
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Contributor

@youzhedian commented Oct 31, 2025

  • It's possible that the underlying cause of this issue is the DCP implementation itself, since I'm observing some strange behavior when testing DCP in 0.10.2 (and on latest main branch, and at ac201a0 when it DCP support was merged). Specifically, prefix retrieval and/or storage is seemingly getting mixed up between requests during high concurrency (but there is no crash), resulting in outputs that seem 'coherent', but are clearly for the wrong request. That problem goes away when adding --no-enable-prefix-caching, or removing -dcp 4.

@josephrocca there are two DCP bugfixes, !26296 and !27518 (still open); your case should be fixed by !26296. You can try cherry-picking these fixes.


Labels

ci/build, documentation, v1

Development

Successfully merging this pull request may close these issues.

[RFC]: Redesigning Persistent Batch in vLLM