
Remove hard-dependencies of Speculative decode to CUDA workers (vllm-project#10587)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
xuechendi authored Nov 27, 2024
1 parent 2f0a0a1 commit 0a71900
Showing 19 changed files with 219 additions and 77 deletions.
4 changes: 2 additions & 2 deletions tests/spec_decode/test_spec_decode_worker.py
@@ -595,8 +595,8 @@ def test_init_device(acceptance_sampler_method: str):
 
     target_worker.init_device.assert_called_once()
 
-    metrics_collector.init_gpu_tensors.assert_called_once()
-    spec_decode_sampler.init_gpu_tensors.assert_called_once()
+    metrics_collector.init_tensors.assert_called_once()
+    spec_decode_sampler.init_tensors.assert_called_once()
 
 
 @pytest.mark.parametrize("acceptance_sampler_method",
1 change: 1 addition & 0 deletions vllm/config.py
@@ -990,6 +990,7 @@ class ParallelConfig:
     # the full name of the worker class to use. If "auto", the worker class
     # will be determined based on the platform.
     worker_cls: str = "auto"
+    sd_worker_cls: str = "auto"
 
     world_size: int = field(init=False)
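The new sd_worker_cls field carries a dotted class path, just like worker_cls: when speculative decoding is enabled, the platform keeps the spec-decode orchestrator as the top-level worker_cls and records in sd_worker_cls which platform-native worker it should wrap. A minimal sketch of how such a string can be turned into a class at runtime (plain Python; the helper name is illustrative, not vLLM's actual resolver):

# Illustrative helper, not vLLM's actual resolver: turn a dotted class
# path such as parallel_config.sd_worker_cls into the class it names.
import importlib

def resolve_worker_cls(qualname: str):
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# Stdlib example of the mechanism; in vLLM the string would instead be
# something like "vllm.worker.cpu_worker.CPUWorker" read from the config.
print(resolve_worker_cls("collections.OrderedDict"))  # <class 'collections.OrderedDict'>
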
17 changes: 16 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -43,6 +43,21 @@ def init_gpu_tensors(self, device: Union[int, str]) -> None:
                                                 dtype=torch.long,
                                                 device=device)
 
+    def init_tensors(self,
+                     device: Union[int, str],
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        assert self.num_accepted_tokens is None
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if isinstance(device, int):
+            device = f"{device_type}:{device}"
+        self.num_accepted_tokens = torch.tensor(0,
+                                                dtype=torch.long,
+                                                device=device)
+        self.num_emitted_tokens = torch.tensor(0,
+                                               dtype=torch.long,
+                                               device=device)
+
     @property
     def probs_dtype(self):
         return torch.float32

@@ -77,7 +92,7 @@ def _create_output(
             tensor is [batch_size, k + num_bonus_tokens]
         """
         batch_size, k = substitute_token_ids.shape
-        bonus_token_ids = bonus_token_ids.squeeze()
+        bonus_token_ids = bonus_token_ids.squeeze(-1)
         # Determine the index of the first False value for each row.
         limits = (accepted == 0).max(1).indices
         limits[~(accepted == 0).any(1)] = k
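Two details here are easy to miss. init_tensors accepts either an integer rank plus a device_type or a ready-made device string, so the counters no longer assume CUDA. And squeeze(-1) only drops the trailing size-1 dimension, whereas the old bare squeeze() would also collapse the batch dimension when batch_size == 1. A small PyTorch check of both points (plain torch, not vLLM code):

# Why squeeze(-1) is the safer call: with batch_size == 1 a bare
# squeeze() collapses the batch dimension too and changes the rank.
import torch

bonus_token_ids = torch.tensor([[7]])     # shape [batch_size=1, 1]
print(bonus_token_ids.squeeze().shape)    # torch.Size([])  - batch dim lost
print(bonus_token_ids.squeeze(-1).shape)  # torch.Size([1]) - batch dim kept

# The same normalization init_tensors performs: an integer rank plus a
# device type becomes a concrete device string, so CPU ranks work too.
device, device_type = 0, torch.device("cpu")
if isinstance(device_type, torch.device):
    device_type = device_type.type
if isinstance(device, int):
    device = f"{device_type}:{device}"
print(device)  # "cpu:0"
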
8 changes: 7 additions & 1 deletion vllm/platforms/cpu.py
@@ -86,4 +86,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                            parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "mp"
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
+            if vllm_config.speculative_config:
+                parallel_config.worker_cls = \
+                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                parallel_config.sd_worker_cls = \
+                    "vllm.worker.cpu_worker.CPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
4 changes: 3 additions & 1 deletion vllm/platforms/cuda.py
@@ -106,6 +106,8 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         elif vllm_config.speculative_config:
             parallel_config.worker_cls = \
                 "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+            parallel_config.sd_worker_cls = \
+                "vllm.worker.worker.Worker"
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"

@@ -236,4 +238,4 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
     if not isinstance(pynvml, _MockModule):
         CudaPlatform.log_warnings()
 except ModuleNotFoundError:
-    CudaPlatform.log_warnings()
\ No newline at end of file
+    CudaPlatform.log_warnings()
24 changes: 12 additions & 12 deletions vllm/spec_decode/draft_model_runner.py
@@ -20,8 +20,9 @@
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
-                                      ModelRunner)
+from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                           ModelRunnerInputBase,
+                                           ModelRunnerWrapperBase)
 
 logger = init_logger(__name__)
 

@@ -33,7 +34,7 @@
 allow_gpu_advance_step = True
 
 
-class TP1DraftModelRunner(ModelRunner):
+class TP1DraftModelRunner(ModelRunnerWrapperBase):
     """Specialized model runner for speculative decoding draft model.
     Since the draft model always execute k forward passes consecutively to
     generate k speculative tokens in a single speculative decoding step,

@@ -46,13 +47,14 @@ class TP1DraftModelRunner(ModelRunner):
     any broadcasting inside execute_model).
     """
 
-    def __init__(self, *args, **kwargs):
-        if kwargs.get("return_hidden_states"):
+    def __init__(self, model_runner: ModelRunnerBase):
+        if hasattr(
+                model_runner,
+                "return_hidden_states") and model_runner.return_hidden_states:
             raise ValueError(
                 "return_hidden_states is not supported for TP1DraftModelRunner."
             )
 
-        super().__init__(*args, **kwargs)
+        super().__init__(model_runner)
 
         self.indices_of_seq_with_bonus_tokens = None
 

@@ -73,10 +75,8 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
             assert seq_group.prompt_logprob_indices == []  # No prompt
             assert seq_group.sample_indices == [i]  # Simple
 
-    def _gpu_advance_step(
-            self, model_input: ModelInputForGPUWithSamplingMetadata,
-            last_output: SamplerOutput
-    ) -> ModelInputForGPUWithSamplingMetadata:
+    def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
+                          last_output: SamplerOutput) -> ModelRunnerInputBase:
         # Currently, we expect "decode mode" only
         assert not model_input.is_prompt
 

@@ -168,7 +168,7 @@ def set_indices_of_seq_with_bonus_tokens(self,
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: ModelInputForGPUWithSamplingMetadata,
+        model_input: ModelRunnerInputBase,
         kv_caches: List[torch.Tensor],
         previous_hidden_states: Optional[torch.Tensor] = None,
         intermediate_tensors: Optional[IntermediateTensors] = None,
8 changes: 5 additions & 3 deletions vllm/spec_decode/interfaces.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional, Set
+from typing import Optional, Set, Union
 
 import torch
 

@@ -75,9 +75,11 @@ def get_spec_proposals(
 
 class SpeculativeScorer(ABC):
 
-    def __init__(self, scorer_worker: WorkerBase, device: str,
-                 vocab_size: int):
+    def __init__(self, scorer_worker: WorkerBase,
+                 device: Union[torch.device, str], vocab_size: int):
         self._scorer_worker = scorer_worker
+        if isinstance(device, torch.device):
+            device = device.type
         self._device = device
         self._vocab_size = vocab_size
 
9 changes: 5 additions & 4 deletions vllm/spec_decode/medusa_worker.py
@@ -9,21 +9,22 @@
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase
 
 
-class MedusaWorker(NonLLMProposerWorkerBase, Worker):
+class MedusaWorker(NonLLMProposerWorkerBase, WorkerWrapperBase):
     """Worker for Medusa.
     """
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)
 
         # Lazy initialization list.
         self._proposer: Top1Proposer
 
     def init_device(self):
-        super().init_device()
+        self.worker.init_device()
 
         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]
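MedusaWorker (and MultiStepWorker below) now builds the platform worker through WorkerWrapperBase instead of inheriting from the CUDA Worker, then forwards calls to self.worker. A simplified sketch of that composition-over-inheritance idea, assuming a hypothetical wrapper class rather than vLLM's actual worker_base implementation:

# Simplified sketch of the wrapper idea; class and attribute names here
# are illustrative, not vLLM's actual WorkerWrapperBase.
class WorkerWrapperSketch:

    def __init__(self, config):
        self.config = config
        self.worker = None  # created later by init_worker()

    def init_worker(self, *args, **kwargs):
        # In vLLM the concrete class would come from a config string such
        # as parallel_config.sd_worker_cls; here we hard-wire a stand-in.
        self.worker = PlatformWorkerStandIn(*args, **kwargs)

    def __getattr__(self, name):
        # Anything the wrapper does not define itself is delegated to the
        # wrapped platform worker.
        return getattr(self.worker, name)


class PlatformWorkerStandIn:

    def __init__(self, *args, **kwargs):
        self.args = args

    def init_device(self):
        print("platform worker initializes its device")


wrapper = WorkerWrapperSketch(config=None)
wrapper.init_worker()
wrapper.init_device()  # delegated via __getattr__
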
15 changes: 14 additions & 1 deletion vllm/spec_decode/metrics.py
@@ -1,11 +1,12 @@
 import time
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 
 import msgspec
 import torch
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeBaseSampler)
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 
 

@@ -81,8 +82,20 @@ def init_gpu_tensors(self, rank: int) -> None:
         self._rank = rank
         self._copy_stream = torch.cuda.Stream()
 
+    def init_tensors(self,
+                     rank: int,
+                     device_type: Union[torch.device, str] = 'cuda') -> None:
+        self._rank = rank
+        if isinstance(device_type, torch.device):
+            device_type = device_type.type
+        if device_type == 'cuda':
+            self._copy_stream = torch.cuda.Stream()
+
     def maybe_collect_rejsample_metrics(
             self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
+        # currently using cuda.Event, skip for any non_cuda_alike platform
+        if not current_platform.is_cuda_alike():
+            return None
 
         # If a copy was initiated in the previous call, collect and return.
         if self._in_flight_copy is not None:
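The metrics collector follows the same pattern: init_tensors only allocates a torch.cuda.Stream when the device type is actually CUDA, and maybe_collect_rejsample_metrics bails out early on other platforms because its timing relies on CUDA events. A small sketch of that guard pattern (a hand-rolled availability check, not vLLM's current_platform helper):

# Sketch of the platform guard, assuming a plain availability check in
# place of vLLM's current_platform.is_cuda_alike().
import torch


def make_copy_stream(device_type: str):
    # Only CUDA gets a copy stream; other device types return None and
    # the caller simply skips stream-based metric collection.
    if device_type == "cuda" and torch.cuda.is_available():
        return torch.cuda.Stream()
    return None


print(make_copy_stream("cpu"))   # None - no CUDA resources created
print(make_copy_stream("cuda"))  # a Stream on CUDA machines, else None
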
31 changes: 24 additions & 7 deletions vllm/spec_decode/multi_step_worker.py
@@ -5,17 +5,21 @@
 import torch
 
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData,
                            SequenceGroupMetadata)
-from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
+if current_platform.is_cuda_alike():
+    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
+
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker import Worker
+from vllm.worker.worker_base import WorkerWrapperBase
 
 
-class MultiStepWorker(Worker, ProposerWorkerBase):
+class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
     """The MultiStepWorker is equivalent to a Worker except that it allows
     multiple forward passes in a single call, assuming the scheduler has
     allocated enough space to store the additional KV. This reduces overhead

@@ -28,13 +32,14 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
     """
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+        super().__init__(kwargs.get("vllm_config"))
+        self.init_worker(*args, **kwargs)
 
         # Lazy initialization list.
         self._proposer: SpeculativeProposer
 
     def init_device(self) -> None:
-        super().init_device()
+        self.worker.init_device()
 
         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]

@@ -51,6 +56,18 @@ def set_should_modify_greedy_probs_inplace(self) -> None:
         self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
             True)
 
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        return self.worker.determine_num_available_blocks()
+
+    def get_cache_block_size_bytes(self) -> int:
+        return self.worker.get_cache_block_size_bytes()
+
+    def initialize_cache(self, *args, **kwargs) -> None:
+        self.worker.initialize_cache(*args, **kwargs)
+
+    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
+        return self.worker.execute_model(*args, **kwargs)
+
     @torch.inference_mode()
     def sampler_output(
         self,

@@ -75,7 +92,7 @@ def sampler_output(
 
         # Run model sample_len times.
         model_outputs: List[SamplerOutput] = []
-        if isinstance(
+        if current_platform.is_cuda_alike() and isinstance(
                 self.model_runner, TP1DraftModelRunner
         ) and self.model_runner.supports_gpu_multi_step(expanded_request):
             # Here we run the draft_model_runner with multi-step prepare

@@ -92,7 +109,7 @@
             # and other restrictions that are part of DraftModelRunner's
             # supports_gpu_multi_step(..)
             for _ in range(sample_len):
-                model_output: List[SamplerOutput] = super().execute_model(
+                model_output: List[SamplerOutput] = self.worker.execute_model(
                     execute_model_req=expanded_request)
                 assert (len(model_output) == 1
                         ), "composing multistep workers not supported"
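Because TP1DraftModelRunner is now only imported when current_platform.is_cuda_alike() is true, any later isinstance check has to be guarded by the same condition, otherwise the name would be undefined on CPU-only deployments. The short-circuiting "and" in sampler_output does exactly that; a minimal illustration of the pattern with a stand-in flag (not vLLM's current_platform API):

# Minimal illustration of guarding a conditionally imported name; the
# flag and class here are stand-ins, not vLLM's current_platform API.
HAS_CUDA_RUNNER = False  # pretend this platform is not CUDA-alike

if HAS_CUDA_RUNNER:
    class TP1DraftRunnerStandIn:  # only defined on "CUDA" platforms
        pass


def can_use_multi_step(model_runner) -> bool:
    # The left-hand side of "and" must be checked first: if the import
    # never happened, referencing the class name would raise NameError.
    return HAS_CUDA_RUNNER and isinstance(model_runner, TP1DraftRunnerStandIn)


print(can_use_multi_step(object()))  # False, and no NameError is raised
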
3 changes: 2 additions & 1 deletion vllm/spec_decode/ngram_worker.py
@@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs):
         # Get local_rank/vocab_size from kwargs attribute
         self.local_rank = kwargs["local_rank"]
         self.vocab_size = kwargs["vllm_config"].model_config.get_vocab_size()
+        self.device_type = kwargs.get("device_type", "cuda")
 
         # Lazy initialization list.
         self._proposer: Top1Proposer

@@ -34,7 +35,7 @@ def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
         self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
 
     def init_device(self):
-        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.device = torch.device(f"{self.device_type}:{self.local_rank}")
        self.load_model = lambda *args, **kwargs: None
 
         # Current NGramWorker only supports Top1Proposer
[Diffs for the remaining changed files are not shown here.]
