@@ -450,6 +450,9 @@ def forward(self, *args, **kwargs):
                                                    selected_token_indices)
         return hidden_states
 
+    def __getattr__(self, attr: str):
+        return getattr(self.model, attr)
+
     def compute_logits(self, *args, **kwargs):
         return self.model.compute_logits(*args, **kwargs)
 
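The `__getattr__` added above makes the adapter forward any attribute it does not define itself to the wrapped model, so callers (for example the speculative-decoding worker looking for the sampler or the LM head) can treat the adapter as if it were the model. A minimal standalone sketch of the same delegation pattern, using a hypothetical `DraftModel` stand-in rather than the real vLLM classes:

class Adapter:
    """Wraps a model and forwards unknown attribute lookups to it."""

    def __init__(self, model):
        # Keep the wrapped model under a single attribute.
        self.model = model

    def __getattr__(self, attr: str):
        # Only called when normal lookup on the adapter fails,
        # so adapter-level attributes (e.g. self.model) still win.
        return getattr(self.model, attr)


class DraftModel:  # hypothetical stand-in for the real model
    hidden_size = 4096

    def compute_logits(self, hidden_states):
        return hidden_states  # placeholder


adapter = Adapter(DraftModel())
print(adapter.hidden_size)  # 4096, resolved on the wrapped model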
@@ -543,6 +546,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     async_callback: Optional[Callable] = None
     is_first_multi_step: bool = True
     is_last_step: bool = True
+    previous_hidden_states: Optional[torch.Tensor] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -645,13 +649,17 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = self.cache_config.cache_dtype
 
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None
 
         # Lazy initialization
         self.lora_manager: LRUCacheWorkerLoRAManager = None
@@ -666,13 +674,29 @@ def __init__(
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        if vllm_config.speculative_config is not None \
+            and vllm_config.scheduler_config.num_scheduler_steps != 1:
+            # Speculative decoding is not supported with multi-step scheduling
+            raise ValueError(
+                "Speculative decoding is not supported with multi-step "
+                "scheduling. Please set num_scheduler_steps to 1.")
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
+        if vllm_config.speculative_config is not None \
+            and self.use_contiguous_pa:
+            logger.warning("Speculative decoding is not supported with "
+                           "contiguous PA, set VLLM_CONTIGUOUS_PA to false")
+            self.use_contiguous_pa = False
+        self.model_type = self.model_config.hf_config.model_type
+        if self.model_type in ("medusa", "mlp_speculator", "eagle"):
+            self.skip_warmup = True
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
         # For delayed sampling
         self.cached_step_inputs: List[
             ModelInputForHPUWithSamplingMetadata] = []
+        self.spec_decode_enabled = \
+            self.vllm_config.speculative_config is not None
 
     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -1496,10 +1520,30 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
+            # in case spec decode, prepare dummy previous_hidden_states
+            additional_inputs = {}
+            if self.model_type in ("medusa", "mlp_speculator", "eagle",
+                                   "deepseek_mtp"):
+                input_tokens = inputs.input_tokens
+                assert input_tokens is not None
+                bs = input_tokens.shape[0]
+                hidden_size = self.model_config.get_hidden_size()
+
+                previous_hidden_states = torch.zeros(
+                    (bs, hidden_size),
+                    device=input_tokens.device,
+                    dtype=self.model_config.dtype)
+                additional_inputs = {
+                    "previous_hidden_states": previous_hidden_states
+                }
+
             is_single_step = \
                 self.vllm_config.scheduler_config.num_scheduler_steps == 1
             if is_prompt or is_single_step:
-                self.execute_model(inputs, None, warmup_mode=True)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   **additional_inputs)
             else:  # decode with multi-step
                 inputs = dataclasses.replace(inputs,
                                              is_first_multi_step=True,
@@ -2055,11 +2099,14 @@ def execute_model(
         num_steps: int = 1,
         warmup_mode=False,
         seqs=None,
+        **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
         use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
         assert not (use_delayed_sampling and num_steps != 1), \
             'Delayed sampling is not compatible with MSS!'
+        assert not (use_delayed_sampling and self.spec_decode_enabled), \
+            'Delayed sampling is not compatible with speculative decoding!'
         assert model_input.input_tokens is not None
         if use_delayed_sampling and not model_input.is_prompt and \
                 self.is_driver_worker:
@@ -2087,6 +2134,7 @@ def execute_model(
                     0, target_indices, self.cached_step_outputs[i])
             htorch.core.mark_step()
 
+        previous_hidden_states = kwargs.get('previous_hidden_states')
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
@@ -2150,13 +2198,32 @@ def execute_model(
                 "virtual_engine": model_input.virtual_engine,
                 **(model_input.multi_modal_kwargs or {}),
             }
+            if previous_hidden_states is not None:
+                # HPU will pad up to block_size,
+                # pad previous_hidden_states as well
+                previous_hidden_states = previous_hidden_states.unsqueeze(
+                    1).expand(-1, input_tokens.shape[-1], -1)
+                batch_size_padding = batch_size - previous_hidden_states.shape[
+                    0]
+                if batch_size_padding > 0:
+                    dummy_previous_hidden_states = torch.zeros(
+                        batch_size_padding,
+                        *previous_hidden_states.shape[1:],
+                        dtype=previous_hidden_states.dtype,
+                        device=previous_hidden_states.device)
+                    previous_hidden_states = torch.cat(
+                        [previous_hidden_states, dummy_previous_hidden_states],
+                        dim=0)
+                execute_model_kwargs.update(
+                    {"previous_hidden_states": previous_hidden_states})
             if htorch.utils.internal.is_lazy():
                 execute_model_kwargs.update(
                     {"bypass_hpu_graphs": not use_graphs})
 
             htorch.core.mark_step()
             if self.is_driver_worker:
                 model_event_name = ("model_"
+                                    f"{self.model_type}_"
                                     f"{'prompt' if is_prompt else 'decode'}_"
                                     f"bs{batch_size}_"
                                     f"seq{seq_len}_"
@@ -2211,6 +2278,7 @@ def try_revert_dummy_output_tokens():
                 with self.profiler.record_event(
                         'internal',
                         ('compute_logits_'
+                         f"{self.model_type}_"
                          f'{"prompt" if is_prompt else "decode"}_bs'
                          f'{batch_size}_'
                          f'seq{seq_len}')):
@@ -2228,6 +2296,7 @@ def try_revert_dummy_output_tokens():
 
                 with self.profiler.record_event(
                         'internal', ('sample_'
+                                     f"{self.model_type}_"
                                      f'{"prompt" if is_prompt else "decode"}_'
                                      f'bs{batch_size}_'
                                      f'seq{seq_len}')):
@@ -2319,9 +2388,18 @@ def try_revert_dummy_output_tokens():
                     is_prompt=is_prompt)
                 self.profiler.record_counter(self.event_start, counters)
             if num_steps == 1:
-                if self.return_hidden_states:
-                    # we only need to pass hidden states of most recent token
+                if self.spec_decode_enabled and isinstance(
+                        output, SamplerOutput):
+                    output.sampled_token_ids = output.sampled_token_ids[:
+                                                                        real_batch_size]
+                    output.sampled_token_probs = output.sampled_token_probs[:
+                                                                            real_batch_size]
+                    output.logprobs = output.logprobs[:real_batch_size]
+                if self.return_hidden_states and isinstance(
+                        output, SamplerOutput):
                     assert model_input.sampling_metadata is not None
+                    # we only need to pass hidden states of most recent token
+                    hidden_states = hidden_states[:real_batch_size]
                     if model_input.is_prompt:
                         output.prefill_hidden_states = hidden_states
                     output.hidden_states = hidden_states
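Because the batch was bucket-padded for HPU graphs, the sampler output and hidden states contain rows for dummy sequences; trimming to `real_batch_size` drops them before the speculative-decoding framework consumes the output. A toy illustration with invented values and shapes:

import torch

# Outputs computed on a bucket-padded batch of 4, of which only the
# first 3 rows correspond to real sequences.
real_batch_size = 3
sampled_token_ids = torch.tensor([[11], [22], [33], [0]])
hidden_states = torch.zeros(4, 4096)

sampled_token_ids = sampled_token_ids[:real_batch_size]  # (3, 1)
hidden_states = hidden_states[:real_batch_size]          # (3, 4096)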
@@ -2330,7 +2408,6 @@ def try_revert_dummy_output_tokens():
                         return [fake_output]
                     else:
                         return []
-
             return [output] if self.is_driver_worker else []
         else:
             return []