@@ -448,6 +448,9 @@ def forward(self, *args, **kwargs):
                                               selected_token_indices)
         return hidden_states
 
+    def __getattr__(self, attr: str):
+        return getattr(self.model, attr)
+
     def compute_logits(self, *args, **kwargs):
         return self.model.compute_logits(*args, **kwargs)
 
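The new `__getattr__` hook makes the HPU model adapter forward any attribute it does not define itself to the wrapped model, so callers (for example a speculative-decoding worker looking up the sampler or embedding weights) need no explicit pass-throughs. A minimal sketch of the pattern, with `InnerModel` and `Adapter` as hypothetical stand-ins for the real classes:

```python
class InnerModel:
    """Hypothetical wrapped model exposing an extra attribute."""

    def __init__(self):
        self.sampler = "sampler-object"  # stand-in for a real sampler object

    def compute_logits(self, hidden):
        return hidden


class Adapter:
    """Wrapper that forwards unknown attribute lookups to the wrapped model."""

    def __init__(self, model):
        self.model = model

    def __getattr__(self, attr: str):
        # Only called when normal lookup fails; "model" itself lives in
        # __dict__, so this delegation cannot recurse.
        return getattr(self.model, attr)


adapter = Adapter(InnerModel())
assert adapter.sampler == "sampler-object"  # resolved on the inner model
assert adapter.compute_logits(3) == 3       # methods delegate as well
```

Because `model` sits in the adapter's instance `__dict__`, normal attribute lookup finds it directly and `__getattr__` only fires for genuinely missing names.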
@@ -541,6 +544,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     async_callback: Optional[Callable] = None
     is_first_multi_step: bool = True
     is_last_step: bool = True
+    previous_hidden_states: Optional[torch.Tensor] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
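`ModelInputForHPU` gains an optional `previous_hidden_states` tensor so the hidden states produced by the previous step can travel with the rest of the model input. A trimmed-down sketch of such a dataclass; `TinyModelInput` and the exact contents of the broadcast dict are illustrative, not the real implementation:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class TinyModelInput:
    """Hypothetical, trimmed-down stand-in for ModelInputForHPU."""
    input_tokens: Optional[torch.Tensor] = None
    is_first_multi_step: bool = True
    is_last_step: bool = True
    previous_hidden_states: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Drop unset optional entries so only populated fields are broadcast.
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "previous_hidden_states": self.previous_hidden_states,
        }
        return {k: v for k, v in tensor_dict.items() if v is not None}


mi = TinyModelInput(input_tokens=torch.tensor([[1, 2, 3]]),
                    previous_hidden_states=torch.zeros(1, 4096))
print(sorted(mi.as_broadcastable_tensor_dict().keys()))
```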
@@ -643,13 +647,17 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = self.cache_config.cache_dtype
 
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None
 
         # Lazy initialization
         self.lora_manager: LRUCacheWorkerLoRAManager = None
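Draft heads used for speculative decoding (e.g. Medusa or an MLP speculator) may have no attention layers at all, in which case `get_num_attention_heads` returns 0 and no attention backend is required. A tiny sketch of the guard, with hypothetical helper names standing in for `get_attn_backend`:

```python
from typing import Optional


def make_backend() -> str:
    # Stand-in for the real backend constructor.
    return "hpu-attn-backend"


def select_backend(num_attn_heads: int,
                   is_attention_free: bool) -> Optional[str]:
    """Only build an attention backend when the model actually needs one."""
    needs_attn_backend = (num_attn_heads != 0 or is_attention_free)
    return make_backend() if needs_attn_backend else None


assert select_backend(32, False) == "hpu-attn-backend"  # ordinary LLM
assert select_backend(0, False) is None                 # e.g. an MLP draft head
```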
@@ -664,7 +672,21 @@ def __init__(
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        if vllm_config.speculative_config is not None \
+                and vllm_config.scheduler_config.num_scheduler_steps != 1:
+            # Speculative decoding is not supported with multi-step scheduling
+            raise ValueError(
+                "Speculative decoding is not supported with multi-step "
+                "scheduling. Please set num_scheduler_steps to 1.")
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
+        if vllm_config.speculative_config is not None \
+                and self.use_contiguous_pa:
+            logger.warning("Speculative decoding is not supported with "
+                           "contiguous PA, set VLLM_CONTIGUOUS_PA to false")
+            self.use_contiguous_pa = False
+        self.model_type = self.model_config.hf_config.model_type
+        if self.model_type in ("medusa", "mlp_speculator", "eagle"):
+            self.skip_warmup = True
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
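The constructor now enforces the supported configuration up front: speculative decoding cannot be combined with multi-step scheduling, contiguous PA is forcibly disabled with a warning, and recognized draft-model types skip warmup. A standalone sketch of the same guard logic over a hypothetical flat config object (not the real vLLM config classes):

```python
import logging
from types import SimpleNamespace

logger = logging.getLogger(__name__)


def apply_spec_decode_guards(cfg: SimpleNamespace) -> SimpleNamespace:
    """Mirror of the constructor checks, on a hypothetical flat config."""
    if cfg.speculative_config is not None and cfg.num_scheduler_steps != 1:
        raise ValueError(
            "Speculative decoding is not supported with multi-step "
            "scheduling. Please set num_scheduler_steps to 1.")
    if cfg.speculative_config is not None and cfg.use_contiguous_pa:
        logger.warning("Speculative decoding is not supported with "
                       "contiguous PA, set VLLM_CONTIGUOUS_PA to false")
        cfg.use_contiguous_pa = False
    if cfg.model_type in ("medusa", "mlp_speculator", "eagle"):
        cfg.skip_warmup = True
    return cfg


cfg = apply_spec_decode_guards(
    SimpleNamespace(speculative_config={"model": "draft"},
                    num_scheduler_steps=1,
                    use_contiguous_pa=True,
                    model_type="eagle",
                    skip_warmup=False))
assert cfg.use_contiguous_pa is False and cfg.skip_warmup is True
```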
@@ -1485,10 +1507,30 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
+            # in case spec decode, prepare dummy previous_hidden_states
+            additional_inputs = {}
+            if self.model_type in ("medusa", "mlp_speculator", "eagle",
+                                   "deepseek_mtp"):
+                input_tokens = inputs.input_tokens
+                assert input_tokens is not None
+                bs = input_tokens.shape[0]
+                hidden_size = self.model_config.get_hidden_size()
+
+                previous_hidden_states = torch.zeros(
+                    (bs, hidden_size),
+                    device=input_tokens.device,
+                    dtype=self.model_config.dtype)
+                additional_inputs = {
+                    "previous_hidden_states": previous_hidden_states
+                }
+
             is_single_step = \
                 self.vllm_config.scheduler_config.num_scheduler_steps == 1
             if is_prompt or is_single_step:
-                self.execute_model(inputs, None, warmup_mode=True)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   **additional_inputs)
             else:  # decode with multi-step
                 inputs = dataclasses.replace(inputs,
                                              is_first_multi_step=True,
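Warmup runs have no real previous step, so for speculative-decoding model types a zero tensor of shape `(batch_size, hidden_size)` stands in for `previous_hidden_states`. A small sketch of building that dummy input; the helper name is hypothetical:

```python
from typing import Dict

import torch


def dummy_spec_decode_inputs(input_tokens: torch.Tensor,
                             hidden_size: int,
                             dtype: torch.dtype) -> Dict[str, torch.Tensor]:
    """Zero-filled previous_hidden_states matching the warmup batch size."""
    bs = input_tokens.shape[0]
    previous_hidden_states = torch.zeros((bs, hidden_size),
                                         device=input_tokens.device,
                                         dtype=dtype)
    return {"previous_hidden_states": previous_hidden_states}


extra = dummy_spec_decode_inputs(torch.zeros(4, 128, dtype=torch.long),
                                 hidden_size=4096,
                                 dtype=torch.bfloat16)
print(extra["previous_hidden_states"].shape)  # torch.Size([4, 4096])
```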
@@ -2029,7 +2071,9 @@ def execute_model(
         num_steps: int = 1,
         warmup_mode=False,
         seqs=None,
+        **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        previous_hidden_states = kwargs.get('previous_hidden_states')
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
@@ -2079,13 +2123,32 @@ def execute_model(
             "virtual_engine": model_input.virtual_engine,
             **(model_input.multi_modal_kwargs or {}),
         }
+        if previous_hidden_states is not None:
+            # HPU will pad up to block_size,
+            # pad previous_hidden_states as well
+            previous_hidden_states = previous_hidden_states.unsqueeze(
+                1).expand(-1, input_tokens.shape[-1], -1)
+            batch_size_padding = batch_size - previous_hidden_states.shape[
+                0]
+            if batch_size_padding > 0:
+                dummy_previous_hidden_states = torch.zeros(
+                    batch_size_padding,
+                    *previous_hidden_states.shape[1:],
+                    dtype=previous_hidden_states.dtype,
+                    device=previous_hidden_states.device)
+                previous_hidden_states = torch.cat(
+                    [previous_hidden_states, dummy_previous_hidden_states],
+                    dim=0)
+            execute_model_kwargs.update(
+                {"previous_hidden_states": previous_hidden_states})
         if htorch.utils.internal.is_lazy():
             execute_model_kwargs.update(
                 {"bypass_hpu_graphs": not use_graphs})
 
         htorch.core.mark_step()
         if self.is_driver_worker:
             model_event_name = ("model_"
+                                f"{self.model_type}_"
                                 f"{'prompt' if is_prompt else 'decode'}_"
                                 f"bs{batch_size}_"
                                 f"seq{seq_len}_"
@@ -2140,6 +2203,7 @@ def try_revert_dummy_output_tokens():
                 with self.profiler.record_event(
                         'internal',
                         ('compute_logits_'
+                         f"{self.model_type}_"
                          f'{"prompt" if is_prompt else "decode"}_bs'
                          f'{batch_size}_'
                          f'seq{seq_len}')):
@@ -2157,6 +2221,7 @@ def try_revert_dummy_output_tokens():
                 # Sample the next token.
                 with self.profiler.record_event(
                         'internal', ('sample_'
+                                     f"{self.model_type}_"
                                      f'{"prompt" if is_prompt else "decode"}_'
                                      f'bs{batch_size}_'
                                      f'seq{seq_len}')):
@@ -2241,6 +2306,18 @@ def try_revert_dummy_output_tokens():
                                             is_prompt=is_prompt)
                 self.profiler.record_counter(self.event_start, counters)
             if num_steps == 1:
+                if self.vllm_config.speculative_config is not None:
+                    output.sampled_token_ids = output.sampled_token_ids[:
+                                                                        real_batch_size]
+                    output.sampled_token_probs = output.sampled_token_probs[:
+                                                                            real_batch_size]
+                    output.logprobs = output.logprobs[:real_batch_size]
+                if self.return_hidden_states:
+                    # we only need to pass hidden states of most recent token
+                    hidden_states = hidden_states[:real_batch_size]
+                    if model_input.is_prompt:
+                        output.prefill_hidden_states = hidden_states
+                    output.hidden_states = hidden_states
                 return [output] if self.is_driver_worker else []
             else:
                 return []
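Since the batch was padded to a bucket size, the sampler output and the hidden states are sliced back to `real_batch_size` before being returned; for prompt steps the hidden states are additionally exposed as `prefill_hidden_states` for the proposer. A sketch of that trimming on plain tensors, using a hypothetical `FakeSamplerOutput` container rather than vLLM's real `SamplerOutput`:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class FakeSamplerOutput:
    """Hypothetical stand-in carrying only the fields trimmed here."""
    sampled_token_ids: torch.Tensor
    sampled_token_probs: torch.Tensor
    logprobs: torch.Tensor
    hidden_states: Optional[torch.Tensor] = None
    prefill_hidden_states: Optional[torch.Tensor] = None


def trim_to_real_batch(output: FakeSamplerOutput, hidden_states: torch.Tensor,
                       real_batch_size: int,
                       is_prompt: bool) -> FakeSamplerOutput:
    # Drop the padded rows that were only there to satisfy the bucket shape.
    output.sampled_token_ids = output.sampled_token_ids[:real_batch_size]
    output.sampled_token_probs = output.sampled_token_probs[:real_batch_size]
    output.logprobs = output.logprobs[:real_batch_size]
    hidden_states = hidden_states[:real_batch_size]
    if is_prompt:
        output.prefill_hidden_states = hidden_states
    output.hidden_states = hidden_states
    return output


padded_bs, real_bs, vocab, hidden = 4, 3, 16, 8
out = trim_to_real_batch(
    FakeSamplerOutput(
        sampled_token_ids=torch.zeros(padded_bs, 1, dtype=torch.long),
        sampled_token_probs=torch.zeros(padded_bs, vocab),
        logprobs=torch.zeros(padded_bs, vocab)),
    hidden_states=torch.zeros(padded_bs, hidden),
    real_batch_size=real_bs,
    is_prompt=True)
print(out.hidden_states.shape)  # torch.Size([3, 8])
```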