Skip to content

Commit 8b88b8a

Browse files
committed
defer return_hidden_states speculation methods
1 parent 2b472ea commit 8b88b8a

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

vllm/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,14 @@ def maybe_create_spec_config(
12921292
"speculative_model unless the draft model config contains an "
12931293
"n_predict parameter.")
12941294

1295+
if enable_chunked_prefill and draft_hf_config.model_type in [
1296+
"medusa", "mlp_speculator", "eagle"
1297+
]:
1298+
raise ValueError(
1299+
"Chunked prefill and hidden-state based draft models are not "
1300+
"yet compatible."
1301+
)
1302+
12951303
if typical_acceptance_sampler_posterior_threshold is None:
12961304
typical_acceptance_sampler_posterior_threshold = 0.09
12971305
if typical_acceptance_sampler_posterior_alpha is None:

vllm/spec_decode/batch_expansion.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,6 @@ def _contract_batch(
195195
else:
196196
all_hidden_states = None
197197

198-
# TODO fix with `return_hidden_states=True` where hidden states are full size,
199-
# and we'll need all indices prior to selecting `do_sample=True`,
200-
# while logits are indexed by `selected_token_indices` True
201-
202198
# Rule out prefills that are in `non_spec_indices` but produce no tokens.
203199
non_spec_indices = [
204200
idx for idx in non_spec_indices

vllm/spec_decode/spec_decode_worker.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,9 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
584584
hidden_states = sampler_output.hidden_states
585585
if hidden_states is not None:
586586
# remove hidden_states for prompt tokens
587+
# TODO Enable `return_hidden_states`: prefill chunks hidden states are
588+
# pruned by the logits processor. Also, they should be arranged back into
589+
# full-prefill latent. Address it to enable MLPSpeculator.
587590
if any(seq.is_prompt
588591
for seq in execute_model_req.seq_group_metadata_list):
589592
hidden_states = hidden_states[
@@ -698,7 +701,7 @@ def _run_speculative_decoding_step(
698701
# TODO skip this if chunking is not enabled
699702
if len(non_spec_indices):
700703
all_hidden_states = proposal_scores.hidden_states
701-
# TODO fix `return_hidden_states`
704+
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
702705
if all_hidden_states is not None:
703706
prefill_hidden_states = all_hidden_states[non_spec_indices]
704707
execute_model_req.previous_hidden_states = prepare_prefill_hidden_states(

0 commit comments

Comments
 (0)