
Commit e3ddd24

Enable spec decode with HPU

Signed-off-by: Chendi Xue <chendi.xue@intel.com>

1 parent 471fe65 · commit e3ddd24

File tree: 8 files changed (+142, -18 lines)


vllm/model_executor/layers/vocab_parallel_embedding.py (+8, -5)

@@ -390,11 +390,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
             # so we're using a workaround. Remove this when fixed in
             # HPU PT bridge.
-            padded_weight = torch.cat([
-                loaded_weight,
-                torch.zeros(param.shape[0] - loaded_weight.shape[0],
-                            *loaded_weight.shape[1:])
-            ])
+            if param.shape[0] > loaded_weight.shape[0]:
+                padded_weight = torch.cat([
+                    loaded_weight,
+                    torch.zeros(param.shape[0] - loaded_weight.shape[0],
+                                *loaded_weight.shape[1:])
+                ])
+            else:
+                padded_weight = loaded_weight
             param.data.copy_(padded_weight)
         else:
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
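Note: the change above only pads the checkpoint shard when the destination parameter actually has more rows. A minimal standalone sketch of that guard, under the assumption that only the leading (vocab) dimension can differ; pad_to_param_rows and the example tensors are illustrative, not part of this commit:

import torch

def pad_to_param_rows(param: torch.Tensor,
                      loaded_weight: torch.Tensor) -> torch.Tensor:
    # Illustrative only: pad the loaded shard with zero rows so it matches
    # the (possibly larger) padded-vocab destination parameter.
    if param.shape[0] > loaded_weight.shape[0]:
        pad_rows = param.shape[0] - loaded_weight.shape[0]
        return torch.cat([
            loaded_weight,
            torch.zeros(pad_rows, *loaded_weight.shape[1:],
                        dtype=loaded_weight.dtype)
        ])
    # Shapes already match: copy as-is instead of concatenating a 0-row pad.
    return loaded_weight

# Example: a 12-row padded parameter receiving a 10-row checkpoint shard.
param = torch.empty(12, 4)
shard = torch.randn(10, 4)
param.data.copy_(pad_to_param_rows(param, shard))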

vllm/platforms/hpu.py (+7, -5)

@@ -51,12 +51,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
 
-        if vllm_config.speculative_config is not None:
-            raise NotImplementedError(
-                "Speculative decoding is not implemented for HPU")
-
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
+            if vllm_config.speculative_config:
+                parallel_config.worker_cls = \
+                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                parallel_config.sd_worker_cls = \
+                    "vllm.worker.hpu_worker.HPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
 
         # NOTE(kzawora): default block size for Gaudi should be 128
         # smaller sizes still work, but very inefficiently
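The hunk above replaces the old NotImplementedError with worker-class routing: with a speculative config, the spec-decode worker factory wraps the HPU worker. A simplified sketch of that branch in isolation; the SimpleNamespace stand-ins for the real config objects are assumptions for illustration:

from types import SimpleNamespace

def pick_hpu_worker(speculative_config, parallel_config):
    # Mirrors the branch above: with a speculative config, route through the
    # spec-decode worker factory and record the HPU worker as the sd worker.
    if parallel_config.worker_cls == "auto":
        if speculative_config:
            parallel_config.worker_cls = (
                "vllm.spec_decode.spec_decode_worker.create_spec_worker")
            parallel_config.sd_worker_cls = "vllm.worker.hpu_worker.HPUWorker"
        else:
            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

parallel = SimpleNamespace(worker_cls="auto", sd_worker_cls="auto")
pick_hpu_worker(speculative_config=object(), parallel_config=parallel)
print(parallel.worker_cls)  # .../create_spec_worker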

vllm/spec_decode/draft_model_runner.py (+13, -3)

@@ -6,6 +6,7 @@
 
 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 
 try:
     try:
@@ -15,9 +16,10 @@
         from vllm.attention.backends.rocm_flash_attn import (
             ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 except (ModuleNotFoundError, ImportError) as err:
-    raise RuntimeError(
-        "Draft model speculative decoding currently only supports "
-        "CUDA and ROCm flash attention backend.") from err
+    if current_platform.is_cuda_alike():
+        raise RuntimeError(
+            "Draft model speculative decoding currently only supports "
+            "CUDA and ROCm flash attention backend.") from err
 
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalKwargs
@@ -36,6 +38,14 @@
 allow_gpu_advance_step = True
 
 
+class GeneralTP1DraftModelRunner(ModelRunnerWrapperBase):
+
+    def __init__(self, model_runner: ModelRunnerBase):
+        super().__init__(model_runner)
+
+        self.indices_of_seq_with_bonus_tokens = None
+
+
 class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding draft model.
     Since the draft model always execute k forward passes consecutively to
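The rewritten import guard above follows a common pattern: try the platform-specific backend, and only escalate the failure on platforms that actually require it. A generic sketch of that pattern; the module and flag names are placeholders, not vLLM APIs:

import importlib

# Placeholders for illustration; in the commit, the role of this flag is
# played by current_platform.is_cuda_alike().
_fast_backend = None
platform_requires_fast_backend = False

try:
    _fast_backend = importlib.import_module("nonexistent_fast_backend")
except (ModuleNotFoundError, ImportError) as err:
    if platform_requires_fast_backend:
        # Only platforms that depend on this backend treat the miss as fatal.
        raise RuntimeError("fast attention backend is required here") from err
    # Other platforms continue with _fast_backend left as None.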

vllm/spec_decode/metrics.py (+2, -0)

@@ -99,6 +99,8 @@ def maybe_collect_rejsample_metrics(
         # Skip for any platform that doesn't have device Event
         if current_platform.Event is None:
             return None
+        if not current_platform.is_cuda_alike():
+            return None
 
         # If a copy was initiated in the previous call, collect and return.
         if self._in_flight_copy is not None:

vllm/spec_decode/spec_decode_worker.py (+5, -0)

@@ -29,6 +29,8 @@
 
 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+else:
+    from vllm.spec_decode.draft_model_runner import GeneralTP1DraftModelRunner
 
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
@@ -186,6 +188,9 @@ def create_worker(
             if current_platform.is_cuda_alike():
                 draft_worker_kwargs[
                     "model_runner_cls"] = TP1DraftModelRunner
+            else:
+                draft_worker_kwargs[
+                    "model_runner_cls"] = GeneralTP1DraftModelRunner
         else:
             if draft_model_config.hf_config.model_type == "eagle":
                 raise NotImplementedError(
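A condensed sketch of the selection the two hunks above implement: CUDA-like platforms keep the specialized draft runner, other platforms (such as HPU) get the new generic wrapper. The stub classes and the is_cuda_alike() stand-in below are illustrative only:

# Illustrative stand-ins; the real classes live in
# vllm.spec_decode.draft_model_runner.
class TP1DraftModelRunner: ...
class GeneralTP1DraftModelRunner: ...

def is_cuda_alike() -> bool:
    # Placeholder for current_platform.is_cuda_alike(); False e.g. on HPU.
    return False

def choose_draft_runner_cls():
    # CUDA/ROCm get the fused multi-step runner; other platforms fall back to
    # the generic single-step wrapper.
    return TP1DraftModelRunner if is_cuda_alike() else GeneralTP1DraftModelRunner

draft_worker_kwargs = {"model_runner_cls": choose_draft_runner_cls()}
print(draft_worker_kwargs["model_runner_cls"].__name__)  # GeneralTP1DraftModelRunner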

vllm/worker/hpu_model_runner.py (+79, -2)

@@ -448,6 +448,9 @@ def forward(self, *args, **kwargs):
                                            selected_token_indices)
         return hidden_states
 
+    def __getattr__(self, attr: str):
+        return getattr(self.model, attr)
+
     def compute_logits(self, *args, **kwargs):
         return self.model.compute_logits(*args, **kwargs)
 
@@ -541,6 +544,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     async_callback: Optional[Callable] = None
     is_first_multi_step: bool = True
     is_last_step: bool = True
+    previous_hidden_states: Optional[torch.Tensor] = None
 
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -643,13 +647,17 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = self.cache_config.cache_dtype
 
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None
 
         # Lazy initialization
         self.lora_manager: LRUCacheWorkerLoRAManager = None
@@ -664,7 +672,21 @@ def __init__(
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        if vllm_config.speculative_config is not None \
+            and vllm_config.scheduler_config.num_scheduler_steps != 1:
+            # Speculative decoding is not supported with multi-step scheduling
+            raise ValueError(
+                "Speculative decoding is not supported with multi-step "
+                "scheduling. Please set num_scheduler_steps to 1.")
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
+        if vllm_config.speculative_config is not None \
+            and self.use_contiguous_pa:
+            logger.warning("Speculative decoding is not supported with "
+                           "contiguous PA, set VLLM_CONTIGUOUS_PA to false")
+            self.use_contiguous_pa = False
+        self.model_type = self.model_config.hf_config.model_type
+        if self.model_type in ("medusa", "mlp_speculator", "eagle"):
+            self.skip_warmup = True
 
         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
@@ -1485,10 +1507,30 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
+            # in case spec decode, prepare dummy previous_hidden_states
+            additional_inputs = {}
+            if self.model_type in ("medusa", "mlp_speculator", "eagle",
+                                   "deepseek_mtp"):
+                input_tokens = inputs.input_tokens
+                assert input_tokens is not None
+                bs = input_tokens.shape[0]
+                hidden_size = self.model_config.get_hidden_size()
+
+                previous_hidden_states = torch.zeros(
+                    (bs, hidden_size),
+                    device=input_tokens.device,
+                    dtype=self.model_config.dtype)
+                additional_inputs = {
+                    "previous_hidden_states": previous_hidden_states
+                }
+
             is_single_step = \
                 self.vllm_config.scheduler_config.num_scheduler_steps == 1
             if is_prompt or is_single_step:
-                self.execute_model(inputs, None, warmup_mode=True)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   **additional_inputs)
             else:  # decode with multi-step
                 inputs = dataclasses.replace(inputs,
                                              is_first_multi_step=True,
@@ -2029,7 +2071,9 @@ def execute_model(
         num_steps: int = 1,
         warmup_mode=False,
         seqs=None,
+        **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        previous_hidden_states = kwargs.get('previous_hidden_states')
        if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
@@ -2079,13 +2123,32 @@
                 "virtual_engine": model_input.virtual_engine,
                 **(model_input.multi_modal_kwargs or {}),
             }
+            if previous_hidden_states is not None:
+                # HPU will pad up to block_size,
+                # pad previous_hidden_states as well
+                previous_hidden_states = previous_hidden_states.unsqueeze(
+                    1).expand(-1, input_tokens.shape[-1], -1)
+                batch_size_padding = batch_size - \
+                    previous_hidden_states.shape[0]
+                if batch_size_padding > 0:
+                    dummy_previous_hidden_states = torch.zeros(
+                        batch_size_padding,
+                        *previous_hidden_states.shape[1:],
+                        dtype=previous_hidden_states.dtype,
+                        device=previous_hidden_states.device)
+                    previous_hidden_states = torch.cat(
+                        [previous_hidden_states, dummy_previous_hidden_states],
+                        dim=0)
+                execute_model_kwargs.update(
+                    {"previous_hidden_states": previous_hidden_states})
             if htorch.utils.internal.is_lazy():
                 execute_model_kwargs.update(
                     {"bypass_hpu_graphs": not use_graphs})
 
             htorch.core.mark_step()
             if self.is_driver_worker:
                 model_event_name = ("model_"
+                                    f"{self.model_type}_"
                                     f"{'prompt' if is_prompt else 'decode'}_"
                                     f"bs{batch_size}_"
                                     f"seq{seq_len}_"
@@ -2140,6 +2203,7 @@ def try_revert_dummy_output_tokens():
                 with self.profiler.record_event(
                         'internal',
                         ('compute_logits_'
+                         f"{self.model_type}_"
                          f'{"prompt" if is_prompt else "decode"}_bs'
                          f'{batch_size}_'
                          f'seq{seq_len}')):
@@ -2157,6 +2221,7 @@
                 # Sample the next token.
                 with self.profiler.record_event(
                         'internal', ('sample_'
+                                     f"{self.model_type}_"
                                      f'{"prompt" if is_prompt else "decode"}_'
                                      f'bs{batch_size}_'
                                      f'seq{seq_len}')):
@@ -2241,6 +2306,18 @@
                                 is_prompt=is_prompt)
                 self.profiler.record_counter(self.event_start, counters)
             if num_steps == 1:
+                if self.vllm_config.speculative_config is not None:
+                    output.sampled_token_ids = \
+                        output.sampled_token_ids[:real_batch_size]
+                    output.sampled_token_probs = \
+                        output.sampled_token_probs[:real_batch_size]
+                    output.logprobs = output.logprobs[:real_batch_size]
+                if self.return_hidden_states:
+                    # we only need to pass hidden states of most recent token
+                    hidden_states = hidden_states[:real_batch_size]
+                    if model_input.is_prompt:
+                        output.prefill_hidden_states = hidden_states
+                    output.hidden_states = hidden_states
                 return [output] if self.is_driver_worker else []
             else:
                 return []
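The previous_hidden_states handling in execute_model can be read as a small padding helper: broadcast each sequence's hidden state across the padded sequence length, then zero-pad the batch dimension to the bucketed batch size. A sketch under those assumptions; the helper name and shapes are illustrative, not part of this commit:

import torch

def pad_previous_hidden_states(prev: torch.Tensor, seq_len: int,
                               batch_size: int) -> torch.Tensor:
    # prev is [real_batch, hidden]; expand to the padded sequence length,
    # then append zero rows up to the padded (bucketed) batch size.
    prev = prev.unsqueeze(1).expand(-1, seq_len, -1)
    batch_size_padding = batch_size - prev.shape[0]
    if batch_size_padding > 0:
        dummy = torch.zeros(batch_size_padding, *prev.shape[1:],
                            dtype=prev.dtype, device=prev.device)
        prev = torch.cat([prev, dummy], dim=0)
    return prev

# Example: 3 real sequences, padded to seq_len=128 and batch_size=4.
hidden = torch.randn(3, 4096)
padded = pad_previous_hidden_states(hidden, seq_len=128, batch_size=4)
print(padded.shape)  # torch.Size([4, 128, 4096])

The matching slice at the end of execute_model (output.* and hidden_states truncated to real_batch_size) undoes this padding before results leave the runner.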

vllm/worker/hpu_worker.py (+17, -3)

@@ -26,7 +26,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.utils import bind_kv_cache
 from vllm.worker.cache_engine import CacheEngine
-from vllm.worker.hpu_model_runner import HPUModelRunner
+from vllm.worker.hpu_model_runner import HPUModelRunner, HPUModelRunnerBase
 from vllm.worker.model_runner_base import ModelRunnerBase
 from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                      WorkerInput)
@@ -65,8 +65,22 @@ def __init__(
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
 
-        self.model_runner: HPUModelRunner = HPUModelRunner(
-            vllm_config=vllm_config, is_driver_worker=is_driver_worker)
+        speculative_config = self.speculative_config
+        model_config = self.model_config
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.hf_config.model_type ==
+                model_config.hf_config.model_type) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ("medusa", "mlp_speculator", "eagle")) \
+            else {"return_hidden_states": True}
+
+        ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
+        self.model_runner: HPUModelRunner = ModelRunnerClass(
+            vllm_config=vllm_config,
+            is_driver_worker=is_driver_worker,
+            **speculative_args)
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(self.model_runner)
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[HPUCacheEngine]
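The ternary above decides whether the target runner must return hidden states for the draft model. A simplified sketch of the same decision as a plain function; the argument names are illustrative:

def needs_hidden_states(draft_model_type: str, target_model_type: str,
                        has_speculative_config: bool = True) -> bool:
    # Hidden states are only needed when the draft model is a head-style
    # speculator (medusa / mlp_speculator / eagle) distinct from the target.
    if not has_speculative_config:
        return False
    if draft_model_type == target_model_type:
        return False
    return draft_model_type in ("medusa", "mlp_speculator", "eagle")

print(needs_hidden_states("medusa", "llama"))  # True
print(needs_hidden_states("llama", "llama"))   # False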

vllm/worker/model_runner_base.py (+11, -0)

@@ -262,6 +262,17 @@ def __init__(
     def __getattr__(self, attr):
         return getattr(self.model_runner, attr)
 
+    def __setattr__(self, name, value):
+        """
+        Ensure that setting the 'model_runner' attribute
+        does not delegate to model_runner
+        """
+
+        if name == "model_runner":
+            object.__setattr__(self, name, value)
+        else:
+            setattr(self.model_runner, name, value)
+
 
 class InputProcessingError(Exception):
     """This exception is raised when an error occurs preparing the inputs for
