Skip to content

[INTEL_HPU][v0] Enable spec decode on HPU #17014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
# so we're using a workaround. Remove this when fixed in
# HPU PT bridge.
padded_weight = torch.cat([
loaded_weight,
torch.zeros(param.shape[0] - loaded_weight.shape[0],
*loaded_weight.shape[1:])
])
if param.shape[0] > loaded_weight.shape[0]:
padded_weight = torch.cat([
loaded_weight,
torch.zeros(param.shape[0] - loaded_weight.shape[0],
*loaded_weight.shape[1:])
])
else:
padded_weight = loaded_weight
param.data.copy_(padded_weight)
else:
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
Expand Down
12 changes: 7 additions & 5 deletions vllm/platforms/hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
parallel_config.worker_cls = \
"vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"

if vllm_config.speculative_config is not None:
raise NotImplementedError(
"Speculative decoding is not implemented for HPU")

if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
if vllm_config.speculative_config:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config.sd_worker_cls = \
"vllm.worker.hpu_worker.HPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"

# NOTE(kzawora): default block size for Gaudi should be 128
# smaller sizes still work, but very inefficiently
Expand Down
16 changes: 13 additions & 3 deletions vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform

try:
try:
Expand All @@ -15,9 +16,10 @@
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except (ModuleNotFoundError, ImportError) as err:
raise RuntimeError(
"Draft model speculative decoding currently only supports "
"CUDA and ROCm flash attention backend.") from err
if current_platform.is_cuda_alike():
raise RuntimeError(
"Draft model speculative decoding currently only supports "
"CUDA and ROCm flash attention backend.") from err

from vllm.logger import init_logger
from vllm.multimodal import MultiModalKwargs
Expand All @@ -36,6 +38,14 @@
allow_gpu_advance_step = True


class GeneralTP1DraftModelRunner(ModelRunnerWrapperBase):

def __init__(self, model_runner: ModelRunnerBase):
super().__init__(model_runner)

self.indices_of_seq_with_bonus_tokens = None


class TP1DraftModelRunner(ModelRunnerWrapperBase):
"""Specialized model runner for speculative decoding draft model.
Since the draft model always execute k forward passes consecutively to
Expand Down
2 changes: 2 additions & 0 deletions vllm/spec_decode/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def maybe_collect_rejsample_metrics(
# Skip for any platform that doesn't have device Event
if current_platform.Event is None:
return None
if not current_platform.is_cuda_alike():
return None

# If a copy was initiated in the previous call, collect and return.
if self._in_flight_copy is not None:
Expand Down
5 changes: 5 additions & 0 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

if current_platform.is_cuda_alike():
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
else:
from vllm.spec_decode.draft_model_runner import GeneralTP1DraftModelRunner

from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
Expand Down Expand Up @@ -186,6 +188,9 @@ def create_worker(
if current_platform.is_cuda_alike():
draft_worker_kwargs[
"model_runner_cls"] = TP1DraftModelRunner
else:
draft_worker_kwargs[
"model_runner_cls"] = GeneralTP1DraftModelRunner
else:
if draft_model_config.hf_config.model_type == "eagle":
raise NotImplementedError(
Expand Down
87 changes: 82 additions & 5 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,9 @@ def forward(self, *args, **kwargs):
selected_token_indices)
return hidden_states

def __getattr__(self, attr: str):
return getattr(self.model, attr)

def compute_logits(self, *args, **kwargs):
return self.model.compute_logits(*args, **kwargs)

Expand Down Expand Up @@ -543,6 +546,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
async_callback: Optional[Callable] = None
is_first_multi_step: bool = True
is_last_step: bool = True
previous_hidden_states: Optional[torch.Tensor] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
Expand Down Expand Up @@ -645,13 +649,17 @@ def __init__(
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = self.cache_config.cache_dtype

num_attn_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
needs_attn_backend = (num_attn_heads != 0
or self.model_config.is_attention_free)
self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
)
) if needs_attn_backend else None

# Lazy initialization
self.lora_manager: LRUCacheWorkerLoRAManager = None
Expand All @@ -666,13 +674,29 @@ def __init__(
self.bucketing_global_state = HPUBucketingGlobalState()
self._setup_buckets()
self._set_gc_threshold()
if vllm_config.speculative_config is not None \
and vllm_config.scheduler_config.num_scheduler_steps != 1:
# Speculative decoding is not supported with multi-step scheduling
raise ValueError(
"Speculative decoding is not supported with multi-step "
"scheduling. Please set num_scheduler_steps to 1.")
self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
if vllm_config.speculative_config is not None \
and self.use_contiguous_pa:
logger.warning("Speculative decoding is not supported with "
"contiguous PA, set VLLM_CONTIGUOUS_PA to false")
self.use_contiguous_pa = False
self.model_type = self.model_config.hf_config.model_type
if self.model_type in ("medusa", "mlp_speculator", "eagle"):
self.skip_warmup = True

# For multi-step scheduling
self.cached_step_outputs: List[torch.Tensor] = []
# For delayed sampling
self.cached_step_inputs: List[
ModelInputForHPUWithSamplingMetadata] = []
self.spec_decode_enabled = \
self.vllm_config.speculative_config is not None

def _set_gc_threshold(self) -> None:
# Read https://docs.python.org/3/library/gc.html#gc.set_threshold
Expand Down Expand Up @@ -1496,10 +1520,30 @@ def warmup_scenario(self,
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
# in case spec decode, prepare dummy previous_hidden_states
additional_inputs = {}
if self.model_type in ("medusa", "mlp_speculator", "eagle",
"deepseek_mtp"):
input_tokens = inputs.input_tokens
assert input_tokens is not None
bs = input_tokens.shape[0]
hidden_size = self.model_config.get_hidden_size()

previous_hidden_states = torch.zeros(
(bs, hidden_size),
device=input_tokens.device,
dtype=self.model_config.dtype)
additional_inputs = {
"previous_hidden_states": previous_hidden_states
}

is_single_step = \
self.vllm_config.scheduler_config.num_scheduler_steps == 1
if is_prompt or is_single_step:
self.execute_model(inputs, None, warmup_mode=True)
self.execute_model(inputs,
None,
warmup_mode=True,
**additional_inputs)
else: # decode with multi-step
inputs = dataclasses.replace(inputs,
is_first_multi_step=True,
Expand Down Expand Up @@ -2055,11 +2099,14 @@ def execute_model(
num_steps: int = 1,
warmup_mode=False,
seqs=None,
**kwargs,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING
use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode
assert not (use_delayed_sampling and num_steps != 1), \
'Delayed sampling is not compatible with MSS!'
assert not (use_delayed_sampling and self.spec_decode_enabled), \
'Delayed sampling is not compatible with speculative decoding!'
assert model_input.input_tokens is not None
if use_delayed_sampling and not model_input.is_prompt and \
self.is_driver_worker:
Expand Down Expand Up @@ -2087,6 +2134,7 @@ def execute_model(
0, target_indices, self.cached_step_outputs[i])
htorch.core.mark_step()

previous_hidden_states = kwargs.get('previous_hidden_states')
if not model_input.is_first_multi_step:
if not model_input.is_last_step:
# not first or last multi-step
Expand Down Expand Up @@ -2150,13 +2198,32 @@ def execute_model(
"virtual_engine": model_input.virtual_engine,
**(model_input.multi_modal_kwargs or {}),
}
if previous_hidden_states is not None:
# HPU will pad up to block_size,
# pad previous_hidden_states as well
previous_hidden_states = previous_hidden_states.unsqueeze(
1).expand(-1, input_tokens.shape[-1], -1)
batch_size_padding = batch_size - previous_hidden_states.shape[
0]
if batch_size_padding > 0:
dummy_previous_hidden_states = torch.zeros(
batch_size_padding,
*previous_hidden_states.shape[1:],
dtype=previous_hidden_states.dtype,
device=previous_hidden_states.device)
previous_hidden_states = torch.cat(
[previous_hidden_states, dummy_previous_hidden_states],
dim=0)
execute_model_kwargs.update(
{"previous_hidden_states": previous_hidden_states})
if htorch.utils.internal.is_lazy():
execute_model_kwargs.update(
{"bypass_hpu_graphs": not use_graphs})

htorch.core.mark_step()
if self.is_driver_worker:
model_event_name = ("model_"
f"{self.model_type}_"
f"{'prompt' if is_prompt else 'decode'}_"
f"bs{batch_size}_"
f"seq{seq_len}_"
Expand Down Expand Up @@ -2211,6 +2278,7 @@ def try_revert_dummy_output_tokens():
with self.profiler.record_event(
'internal',
('compute_logits_'
f"{self.model_type}_"
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
Expand All @@ -2228,6 +2296,7 @@ def try_revert_dummy_output_tokens():

with self.profiler.record_event(
'internal', ('sample_'
f"{self.model_type}_"
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
Expand Down Expand Up @@ -2319,9 +2388,18 @@ def try_revert_dummy_output_tokens():
is_prompt=is_prompt)
self.profiler.record_counter(self.event_start, counters)
if num_steps == 1:
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if self.spec_decode_enabled and isinstance(
output, SamplerOutput):
output.sampled_token_ids = output.sampled_token_ids[:
real_batch_size]
output.sampled_token_probs = output.sampled_token_probs[:
real_batch_size]
output.logprobs = output.logprobs[:real_batch_size]
if self.return_hidden_states and isinstance(
output, SamplerOutput):
assert model_input.sampling_metadata is not None
# we only need to pass hidden states of most recent token
hidden_states = hidden_states[:real_batch_size]
if model_input.is_prompt:
output.prefill_hidden_states = hidden_states
output.hidden_states = hidden_states
Expand All @@ -2330,7 +2408,6 @@ def try_revert_dummy_output_tokens():
return [fake_output]
else:
return []

return [output] if self.is_driver_worker else []
else:
return []
Expand Down
16 changes: 15 additions & 1 deletion vllm/worker/hpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,22 @@ def __init__(
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()

speculative_config = self.speculative_config
model_config = self.model_config
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.hf_config.model_type ==
model_config.hf_config.model_type) \
or (speculative_config.draft_model_config.hf_config.model_type
not in ("medusa", "mlp_speculator", "eagle")) \
else {"return_hidden_states": True}

self.model_runner: HPUModelRunner = HPUModelRunner(
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
vllm_config=vllm_config,
is_driver_worker=is_driver_worker,
**speculative_args)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(
self.model_runner) # type: ignore
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
self.cache_engine: List[HPUCacheEngine]
Expand Down
11 changes: 11 additions & 0 deletions vllm/worker/model_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,17 @@ def __init__(
def __getattr__(self, attr):
return getattr(self.model_runner, attr)

def __setattr__(self, name, value):
"""
Ensure that setting the 'model_runner' attribute
does not delegate to model_runner
"""

if name == "model_runner":
object.__setattr__(self, name, value)
else:
setattr(self.model_runner, name, value)


class InputProcessingError(Exception):
"""This exception is raised when an error occurs preparing the inputs for
Expand Down