Skip to content

Commit 9a8bb01

Browse files
njhill and abhigoyal1997
authored and committed
[BugFix] Fix cuda graph for MLPSpeculator (vllm-project#5875)
Co-authored-by: Abhinav Goyal <abhinav.goyal@flipkart.com> Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent c07ce3f commit 9a8bb01

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

examples/offline_inference_mlpspeculator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def time_generation(llm: LLM, prompts: List[str],
5252
speculative_model="ibm-fms/llama-13b-accelerator",
5353
# These are currently required for MLPSpeculator decoding
5454
use_v2_block_manager=True,
55-
enforce_eager=True,
5655
)
5756

5857
print("With speculation")

vllm/worker/model_runner.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,10 +1020,13 @@ def execute_model(
10201020

10211021
if self.return_hidden_states:
10221022
# we only need to pass hidden states of most recent token
1023+
assert model_input.sampling_metadata is not None
1024+
indices = model_input.sampling_metadata.selected_token_indices
10231025
if model_input.is_prompt:
1024-
assert model_input.sampling_metadata is not None
1025-
hidden_states = hidden_states.index_select(
1026-
0, model_input.sampling_metadata.selected_token_indices)
1026+
hidden_states = hidden_states.index_select(0, indices)
1027+
elif decode_meta.use_cuda_graph:
1028+
hidden_states = hidden_states[:len(indices)]
1029+
10271030
output.hidden_states = hidden_states
10281031

10291032
return output

0 commit comments

Comments (0)