Commit 4f709d6

Enable spec decode with HPU
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
1 parent 8e630d6 commit 4f709d6

File tree

  vllm/model_executor/layers/vocab_parallel_embedding.py
  vllm/platforms/hpu.py
  vllm/spec_decode/draft_model_runner.py
  vllm/spec_decode/metrics.py
  vllm/spec_decode/spec_decode_worker.py
  vllm/worker/hpu_model_runner.py
  vllm/worker/hpu_worker.py
  vllm/worker/model_runner_base.py

8 files changed: +143 -19 lines changed

vllm/model_executor/layers/vocab_parallel_embedding.py

+8-5
@@ -390,11 +390,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
             # so we're using a workaround. Remove this when fixed in
             # HPU PT bridge.
-            padded_weight = torch.cat([
-                loaded_weight,
-                torch.zeros(param.shape[0] - loaded_weight.shape[0],
-                            *loaded_weight.shape[1:])
-            ])
+            if param.shape[0] > loaded_weight.shape[0]:
+                padded_weight = torch.cat([
+                    loaded_weight,
+                    torch.zeros(param.shape[0] - loaded_weight.shape[0],
+                                *loaded_weight.shape[1:])
+                ])
+            else:
+                padded_weight = loaded_weight
             param.data.copy_(padded_weight)
         else:
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
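
Note: the new branch only pads when the destination parameter has more rows than the checkpoint shard; when the shapes already match, the weight is used as-is. A minimal, standalone sketch of that behavior (the helper name and shapes below are illustrative, not part of the commit; requires PyTorch):

import torch

def pad_rows_to(loaded_weight: torch.Tensor, target_rows: int) -> torch.Tensor:
    # Append zero rows only if the destination is larger (e.g. a vocab padded
    # up to a multiple of the TP shard size); otherwise return it unchanged,
    # mirroring the else-branch added above.
    if target_rows > loaded_weight.shape[0]:
        pad = torch.zeros(target_rows - loaded_weight.shape[0],
                          *loaded_weight.shape[1:],
                          dtype=loaded_weight.dtype)
        return torch.cat([loaded_weight, pad])
    return loaded_weight

# Example: a (5, 4) checkpoint shard copied into an (8, 4) padded parameter.
shard = torch.randn(5, 4)
param = torch.empty(8, 4)
param.copy_(pad_rows_to(shard, param.shape[0]))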

vllm/platforms/hpu.py

+7-5
@@ -51,12 +51,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             parallel_config.worker_cls = \
                 "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"

-        if vllm_config.speculative_config is not None:
-            raise NotImplementedError(
-                "Speculative decoding is not implemented for HPU")
-
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
+            if vllm_config.speculative_config:
+                parallel_config.worker_cls = \
+                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                parallel_config.sd_worker_cls = \
+                    "vllm.worker.hpu_worker.HPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

         # NOTE(kzawora): default block size for Gaudi should be 128
         # smaller sizes still work, but very inefficiently
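
Note: with a speculative config present, HPU now routes worker construction through vLLM's generic spec-decode factory and records the HPU worker as the underlying sd_worker_cls. A rough sketch of just the selection logic, using a stand-in dataclass rather than vLLM's real config types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ParallelConfigStub:  # stand-in for vLLM's ParallelConfig
    worker_cls: str = "auto"
    sd_worker_cls: Optional[str] = None

def pick_hpu_worker(cfg: ParallelConfigStub, has_spec_config: bool) -> None:
    # Mirrors the new branch: spec decode goes through create_spec_worker,
    # which later builds the actual HPU worker named in sd_worker_cls.
    if cfg.worker_cls == "auto":
        if has_spec_config:
            cfg.worker_cls = \
                "vllm.spec_decode.spec_decode_worker.create_spec_worker"
            cfg.sd_worker_cls = "vllm.worker.hpu_worker.HPUWorker"
        else:
            cfg.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

cfg = ParallelConfigStub()
pick_hpu_worker(cfg, has_spec_config=True)
print(cfg.worker_cls, cfg.sd_worker_cls)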

vllm/spec_decode/draft_model_runner.py

+13-3
@@ -6,6 +6,7 @@

 from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform

 try:
     try:
@@ -15,9 +16,10 @@
         from vllm.attention.backends.rocm_flash_attn import (
             ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 except (ModuleNotFoundError, ImportError) as err:
-    raise RuntimeError(
-        "Draft model speculative decoding currently only supports "
-        "CUDA and ROCm flash attention backend.") from err
+    if current_platform.is_cuda_alike():
+        raise RuntimeError(
+            "Draft model speculative decoding currently only supports "
+            "CUDA and ROCm flash attention backend.") from err

 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalKwargs
@@ -36,6 +38,14 @@
 allow_gpu_advance_step = True


+class GeneralTP1DraftModelRunner(ModelRunnerWrapperBase):
+
+    def __init__(self, model_runner: ModelRunnerBase):
+        super().__init__(model_runner)
+
+        self.indices_of_seq_with_bonus_tokens = None
+
+
 class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding draft model.
     Since the draft model always execute k forward passes consecutively to
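
Note: the import guard now raises only on CUDA-alike platforms, so the module stays importable on HPU even though neither flash-attention backend exists there, and GeneralTP1DraftModelRunner is the plain wrapper those platforms use instead of TP1DraftModelRunner. A simplified, self-contained sketch of the same guard pattern (the helper below is hypothetical, not vLLM API):

import importlib

def load_optional_backend(module_name: str, is_cuda_alike: bool):
    # Only CUDA-alike platforms treat a missing flash-attention backend
    # as a fatal error; other platforms simply continue without it.
    try:
        return importlib.import_module(module_name)
    except (ModuleNotFoundError, ImportError) as err:
        if is_cuda_alike:
            raise RuntimeError(
                "Draft model speculative decoding currently only supports "
                "CUDA and ROCm flash attention backend.") from err
        return None

# On HPU this returns None instead of raising when the backend is absent.
backend = load_optional_backend("vllm.attention.backends.flash_attn",
                                is_cuda_alike=False)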

vllm/spec_decode/metrics.py

+2
@@ -99,6 +99,8 @@ def maybe_collect_rejsample_metrics(
         # Skip for any platform that doesn't have device Event
         if current_platform.Event is None:
             return None
+        if not current_platform.is_cuda_alike():
+            return None

         # If a copy was initiated in the previous call, collect and return.
         if self._in_flight_copy is not None:

vllm/spec_decode/spec_decode_worker.py

+5
@@ -29,6 +29,8 @@

 if current_platform.is_cuda_alike():
     from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+else:
+    from vllm.spec_decode.draft_model_runner import GeneralTP1DraftModelRunner

 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
@@ -186,6 +188,9 @@ def create_worker(
                 if current_platform.is_cuda_alike():
                     draft_worker_kwargs[
                         "model_runner_cls"] = TP1DraftModelRunner
+                else:
+                    draft_worker_kwargs[
+                        "model_runner_cls"] = GeneralTP1DraftModelRunner
             else:
                 if draft_model_config.hf_config.model_type == "eagle":
                     raise NotImplementedError(

vllm/worker/hpu_model_runner.py

+82-5
@@ -450,6 +450,9 @@ def forward(self, *args, **kwargs):
                                                selected_token_indices)
         return hidden_states

+    def __getattr__(self, attr: str):
+        return getattr(self.model, attr)
+
     def compute_logits(self, *args, **kwargs):
         return self.model.compute_logits(*args, **kwargs)

@@ -543,6 +546,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
     async_callback: Optional[Callable] = None
     is_first_multi_step: bool = True
     is_last_step: bool = True
+    previous_hidden_states: Optional[torch.Tensor] = None

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -645,13 +649,17 @@ def __init__(
         self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = self.cache_config.cache_dtype

+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
+        ) if needs_attn_backend else None

         # Lazy initialization
         self.lora_manager: LRUCacheWorkerLoRAManager = None
@@ -666,13 +674,29 @@ def __init__(
         self.bucketing_global_state = HPUBucketingGlobalState()
         self._setup_buckets()
         self._set_gc_threshold()
+        if vllm_config.speculative_config is not None \
+            and vllm_config.scheduler_config.num_scheduler_steps != 1:
+            # Speculative decoding is not supported with multi-step scheduling
+            raise ValueError(
+                "Speculative decoding is not supported with multi-step "
+                "scheduling. Please set num_scheduler_steps to 1.")
         self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
+        if vllm_config.speculative_config is not None \
+            and self.use_contiguous_pa:
+            logger.warning("Speculative decoding is not supported with "
+                           "contiguous PA, set VLLM_CONTIGUOUS_PA to false")
+            self.use_contiguous_pa = False
+        self.model_type = self.model_config.hf_config.model_type
+        if self.model_type in ("medusa", "mlp_speculator", "eagle"):
+            self.skip_warmup = True

         # For multi-step scheduling
         self.cached_step_outputs: List[torch.Tensor] = []
         # For delayed sampling
         self.cached_step_inputs: List[
             ModelInputForHPUWithSamplingMetadata] = []
+        self.spec_decode_enabled = \
+            self.vllm_config.speculative_config is not None

     def _set_gc_threshold(self) -> None:
         # Read https://docs.python.org/3/library/gc.html#gc.set_threshold
@@ -1496,10 +1520,30 @@ def warmup_scenario(self,
             profiler.start()
         for _ in range(times):
             inputs = self.prepare_model_input(seqs)
+            # in case spec decode, prepare dummy previous_hidden_states
+            additional_inputs = {}
+            if self.model_type in ("medusa", "mlp_speculator", "eagle",
+                                   "deepseek_mtp"):
+                input_tokens = inputs.input_tokens
+                assert input_tokens is not None
+                bs = input_tokens.shape[0]
+                hidden_size = self.model_config.get_hidden_size()
+
+                previous_hidden_states = torch.zeros(
+                    (bs, hidden_size),
+                    device=input_tokens.device,
+                    dtype=self.model_config.dtype)
+                additional_inputs = {
+                    "previous_hidden_states": previous_hidden_states
+                }
+
             is_single_step = \
                 self.vllm_config.scheduler_config.num_scheduler_steps == 1
             if is_prompt or is_single_step:
-                self.execute_model(inputs, None, warmup_mode=True)
+                self.execute_model(inputs,
+                                   None,
+                                   warmup_mode=True,
+                                   **additional_inputs)
             else:  # decode with multi-step
                 inputs = dataclasses.replace(inputs,
                                              is_first_multi_step=True,
@@ -2055,11 +2099,14 @@ def execute_model(
         num_steps: int = 1,
         warmup_mode=False,
         seqs=None,
+        **kwargs,
     ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
         use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
         assert not (use_delayed_sampling and num_steps != 1), \
             'Delayed sampling is not compatible with MSS!'
+        assert not (use_delayed_sampling and self.spec_decode_enabled), \
+            'Delayed sampling is not compatible with speculative decoding!'
         assert model_input.input_tokens is not None
         if use_delayed_sampling and not model_input.is_prompt and \
                 self.is_driver_worker:
@@ -2087,6 +2134,7 @@ def execute_model(
                     0, target_indices, self.cached_step_outputs[i])
             htorch.core.mark_step()

+        previous_hidden_states = kwargs.get('previous_hidden_states')
         if not model_input.is_first_multi_step:
             if not model_input.is_last_step:
                 # not first or last multi-step
@@ -2150,13 +2198,32 @@
                 "virtual_engine": model_input.virtual_engine,
                 **(model_input.multi_modal_kwargs or {}),
             }
+            if previous_hidden_states is not None:
+                # HPU will pad up to block_size,
+                # pad previous_hidden_states as well
+                previous_hidden_states = previous_hidden_states.unsqueeze(
+                    1).expand(-1, input_tokens.shape[-1], -1)
+                batch_size_padding = batch_size - previous_hidden_states.shape[
+                    0]
+                if batch_size_padding > 0:
+                    dummy_previous_hidden_states = torch.zeros(
+                        batch_size_padding,
+                        *previous_hidden_states.shape[1:],
+                        dtype=previous_hidden_states.dtype,
+                        device=previous_hidden_states.device)
+                    previous_hidden_states = torch.cat(
+                        [previous_hidden_states, dummy_previous_hidden_states],
+                        dim=0)
+                execute_model_kwargs.update(
+                    {"previous_hidden_states": previous_hidden_states})
             if htorch.utils.internal.is_lazy():
                 execute_model_kwargs.update(
                     {"bypass_hpu_graphs": not use_graphs})

             htorch.core.mark_step()
             if self.is_driver_worker:
                 model_event_name = ("model_"
+                                    f"{self.model_type}_"
                                     f"{'prompt' if is_prompt else 'decode'}_"
                                     f"bs{batch_size}_"
                                     f"seq{seq_len}_"
@@ -2211,6 +2278,7 @@ def try_revert_dummy_output_tokens():
                 with self.profiler.record_event(
                         'internal',
                         ('compute_logits_'
+                         f"{self.model_type}_"
                          f'{"prompt" if is_prompt else "decode"}_bs'
                          f'{batch_size}_'
                          f'seq{seq_len}')):
@@ -2228,6 +2296,7 @@ def try_revert_dummy_output_tokens():

                 with self.profiler.record_event(
                         'internal', ('sample_'
+                                     f"{self.model_type}_"
                                      f'{"prompt" if is_prompt else "decode"}_'
                                      f'bs{batch_size}_'
                                      f'seq{seq_len}')):
@@ -2319,9 +2388,18 @@ def try_revert_dummy_output_tokens():
                     is_prompt=is_prompt)
                 self.profiler.record_counter(self.event_start, counters)
             if num_steps == 1:
-                if self.return_hidden_states:
-                    # we only need to pass hidden states of most recent token
+                if self.spec_decode_enabled and isinstance(
+                        output, SamplerOutput):
+                    output.sampled_token_ids = output.sampled_token_ids[:
+                        real_batch_size]
+                    output.sampled_token_probs = output.sampled_token_probs[:
+                        real_batch_size]
+                    output.logprobs = output.logprobs[:real_batch_size]
+                if self.return_hidden_states and isinstance(
+                        output, SamplerOutput):
                     assert model_input.sampling_metadata is not None
+                    # we only need to pass hidden states of most recent token
+                    hidden_states = hidden_states[:real_batch_size]
                     if model_input.is_prompt:
                         output.prefill_hidden_states = hidden_states
                     output.hidden_states = hidden_states
@@ -2330,7 +2408,6 @@ def try_revert_dummy_output_tokens():
                     return [fake_output]
                 else:
                     return []
-
             return [output] if self.is_driver_worker else []
         else:
             return []
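
Note: the HPU runner pads both the batch and the sequence dimension to fixed bucket sizes, so the draft model's previous_hidden_states has to be expanded and zero-padded the same way before it is passed to the model. A standalone sketch of that padding step, assuming a (real_batch_size, hidden_size) tensor as in the hunk above (requires PyTorch; names are illustrative):

import torch

def pad_previous_hidden_states(prev: torch.Tensor, seq_len: int,
                               padded_batch_size: int) -> torch.Tensor:
    # (bs, hidden) -> (bs, seq_len, hidden): broadcast across the padded
    # token positions, mirroring unsqueeze(1).expand(...) in the diff.
    prev = prev.unsqueeze(1).expand(-1, seq_len, -1)
    # Append zero rows so the batch dimension matches the padded bucket size.
    missing = padded_batch_size - prev.shape[0]
    if missing > 0:
        pad = torch.zeros(missing, *prev.shape[1:],
                          dtype=prev.dtype, device=prev.device)
        prev = torch.cat([prev, pad], dim=0)
    return prev

hidden = torch.randn(3, 16)  # 3 real sequences, hidden size 16
padded = pad_previous_hidden_states(hidden, seq_len=8, padded_batch_size=4)
print(padded.shape)  # torch.Size([4, 8, 16])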

vllm/worker/hpu_worker.py

+15-1
@@ -65,8 +65,22 @@ def __init__(
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()

+        speculative_config = self.speculative_config
+        model_config = self.model_config
+        speculative_args = {} if speculative_config is None \
+            or (speculative_config.draft_model_config.hf_config.model_type ==
+                model_config.hf_config.model_type) \
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ("medusa", "mlp_speculator", "eagle")) \
+            else {"return_hidden_states": True}
+
         self.model_runner: HPUModelRunner = HPUModelRunner(
-            vllm_config=vllm_config, is_driver_worker=is_driver_worker)
+            vllm_config=vllm_config,
+            is_driver_worker=is_driver_worker,
+            **speculative_args)
+        if model_runner_cls is not None:
+            self.model_runner = model_runner_cls(
+                self.model_runner)  # type: ignore
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
         self.cache_engine: List[HPUCacheEngine]
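
Note: the worker only asks the runner to return hidden states when the draft model is a different, hidden-state-conditioned head (medusa, mlp_speculator, eagle); a same-architecture draft or no speculative config leaves the extra output off. A condensed sketch of that decision with plain strings standing in for the config objects (the helper is hypothetical):

from typing import Optional

def needs_hidden_states(target_model_type: str,
                        draft_model_type: Optional[str]) -> bool:
    # No speculative config, a same-architecture draft model, or a draft
    # type outside the hidden-state-conditioned family -> no extra output.
    if draft_model_type is None or draft_model_type == target_model_type:
        return False
    return draft_model_type in ("medusa", "mlp_speculator", "eagle")

speculative_args = ({"return_hidden_states": True}
                    if needs_hidden_states("llama", "medusa") else {})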

vllm/worker/model_runner_base.py

+11
@@ -262,6 +262,17 @@ def __init__(
     def __getattr__(self, attr):
         return getattr(self.model_runner, attr)

+    def __setattr__(self, name, value):
+        """
+        Ensure that setting the 'model_runner' attribute
+        does not delegate to model_runner
+        """
+
+        if name == "model_runner":
+            object.__setattr__(self, name, value)
+        else:
+            setattr(self.model_runner, name, value)
+

 class InputProcessingError(Exception):
     """This exception is raised when an error occurs preparing the inputs for
