[Speculative Decoding] Support draft model on different tensor-parallel size than target model #5414
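For orientation, a minimal usage sketch of what this change enables, assuming the speculative_tensor_parallel_size engine argument added in this diff is forwarded through the LLM entrypoint (later commits in this PR rename it toward speculative_draft_tensor_parallel_size / spec-draft-tp). Model names and values are illustrative, not taken from the PR:

from vllm import LLM, SamplingParams

# Illustrative sketch (not the PR's own example): the target model runs with
# TP=4 while the draft model runs with TP=1, reducing the draft model's
# communication overhead during speculative decoding.
llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",   # assumed target model
    tensor_parallel_size=4,
    speculative_model="JackFram/llama-68m",   # assumed draft model
    speculative_tensor_parallel_size=1,       # argument name at this 51-commit snapshot
    num_speculative_tokens=5,
    use_v2_block_manager=True,                # required for speculative decoding at the time
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)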

Merged · 131 commits · Jun 25, 2024

Changes from 51 commits

Commits (131)
f5b5f94
tp1 draft worker
wooyeonlee0 Jun 10, 2024
709de21
refactor singlt_tp_worker
wooyeonlee0 Jun 10, 2024
0eacc96
update execute_model logic
wooyeonlee0 Jun 10, 2024
2011ed0
fix
wooyeonlee0 Jun 11, 2024
2e16c4e
DummyProposerWorker
wooyeonlee0 Jun 11, 2024
b412a51
fix
wooyeonlee0 Jun 11, 2024
593ccfa
init only partial workers
wooyeonlee0 Jun 11, 2024
c5d3476
Use multi_step_worker logic
wooyeonlee0 Jun 12, 2024
44e623b
self._patch_tp_group
wooyeonlee0 Jun 12, 2024
98caf17
refactor it to support other draft-tp than 1
wooyeonlee0 Jun 12, 2024
7fc4ff5
spec-tp configuarable
wooyeonlee0 Jun 12, 2024
a96e720
ngram worker support test
wooyeonlee0 Jun 12, 2024
db39576
minor refine
wooyeonlee0 Jun 12, 2024
b2e8595
cleanup
wooyeonlee0 Jun 12, 2024
756442a
return type fix
wooyeonlee0 Jun 12, 2024
32094f1
cleanup
wooyeonlee0 Jun 12, 2024
7890191
cleanup
wooyeonlee0 Jun 12, 2024
53b2ea9
typo
wooyeonlee0 Jun 12, 2024
a29c9c5
verify arg
wooyeonlee0 Jun 12, 2024
52ba09d
remove testing code
wooyeonlee0 Jun 12, 2024
d26ef08
cleanup
wooyeonlee0 Jun 12, 2024
80c4994
rename module
wooyeonlee0 Jun 12, 2024
0f16f3f
cleanup
wooyeonlee0 Jun 12, 2024
140f478
cleanup
wooyeonlee0 Jun 12, 2024
3fd7e91
remove unnecessary methods
wooyeonlee0 Jun 12, 2024
495aa30
fix
wooyeonlee0 Jun 12, 2024
3a5a47f
undo unrelated changes
wooyeonlee0 Jun 12, 2024
07ddbb8
minor fix
wooyeonlee0 Jun 12, 2024
b0a677d
fix ruff errors
wooyeonlee0 Jun 12, 2024
96782a2
Merge branch 'main' into spec-tp1-draft
wooyeonlee0 Jun 12, 2024
9998b9c
typo
wooyeonlee0 Jun 12, 2024
e92ecdc
temporal fix
wooyeonlee0 Jun 12, 2024
b421607
formatting
wooyeonlee0 Jun 12, 2024
386ab9b
isort
wooyeonlee0 Jun 12, 2024
b25f74e
line length
wooyeonlee0 Jun 12, 2024
8b51f08
fix
wooyeonlee0 Jun 13, 2024
d4b283c
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 13, 2024
dfc90cb
line length
wooyeonlee0 Jun 13, 2024
9bef5e4
comment
wooyeonlee0 Jun 13, 2024
85d087d
add type hint
wooyeonlee0 Jun 13, 2024
9af36b7
isort
wooyeonlee0 Jun 13, 2024
5a0bf45
add more type hints
wooyeonlee0 Jun 13, 2024
531c9f0
fix
wooyeonlee0 Jun 13, 2024
287da20
test
wooyeonlee0 Jun 13, 2024
08d1b2a
nit
wooyeonlee0 Jun 13, 2024
237c966
fix yapf
wooyeonlee0 Jun 13, 2024
0bb38c2
fix
wooyeonlee0 Jun 13, 2024
c097d6c
fix
wooyeonlee0 Jun 13, 2024
957a325
fix
wooyeonlee0 Jun 13, 2024
3ec8cb5
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 14, 2024
8a8a1e4
add comments
wooyeonlee0 Jun 14, 2024
7f06f64
combine smaller_tp_worker logic into multi_step_worker
wooyeonlee0 Jun 14, 2024
1e87579
fix
wooyeonlee0 Jun 14, 2024
abc546c
fix
wooyeonlee0 Jun 14, 2024
7880cb0
add small_tp correctness test
wooyeonlee0 Jun 14, 2024
2ebe6f3
nit
wooyeonlee0 Jun 14, 2024
90d46ee
fix
wooyeonlee0 Jun 14, 2024
7e1426c
refactor. remove log
wooyeonlee0 Jun 14, 2024
ad52d93
remove return
wooyeonlee0 Jun 14, 2024
355475b
fix
wooyeonlee0 Jun 14, 2024
9cfdb5b
fix about context managing
wooyeonlee0 Jun 14, 2024
6a6c5ff
nit
wooyeonlee0 Jun 14, 2024
ddef229
consistent condition. if self._is_dummy:
wooyeonlee0 Jun 14, 2024
965f648
fix ruff errors
wooyeonlee0 Jun 14, 2024
1bb5534
isort
wooyeonlee0 Jun 14, 2024
ea6b8f5
fix yapf
wooyeonlee0 Jun 14, 2024
71977d2
undo ngramworker support
wooyeonlee0 Jun 14, 2024
bc5f77a
add comment
wooyeonlee0 Jun 14, 2024
5655a49
remove smaller_tp_proposer_worker
wooyeonlee0 Jun 14, 2024
eabc16a
ruff
wooyeonlee0 Jun 14, 2024
f748edf
remove ranks arg
wooyeonlee0 Jun 17, 2024
c099c94
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 17, 2024
4b74a45
undo
wooyeonlee0 Jun 17, 2024
c9786ad
add dist test
wooyeonlee0 Jun 17, 2024
a42664a
nit
wooyeonlee0 Jun 17, 2024
ac7701a
fix
wooyeonlee0 Jun 17, 2024
eea6a7e
test fix
wooyeonlee0 Jun 17, 2024
a648f5d
yapf fix
wooyeonlee0 Jun 17, 2024
f23ba8c
update comment
wooyeonlee0 Jun 17, 2024
aa9af93
require 2 gpus
wooyeonlee0 Jun 17, 2024
56c8927
restore draft_ranks arg in MultiStepWorker.__init__
wooyeonlee0 Jun 18, 2024
385b4f8
comment
wooyeonlee0 Jun 18, 2024
43f37eb
ruff mypy
wooyeonlee0 Jun 18, 2024
99350e2
isort
wooyeonlee0 Jun 18, 2024
a9f3e23
yapf
wooyeonlee0 Jun 18, 2024
6ba250d
allow None for draft_ranks
wooyeonlee0 Jun 18, 2024
3e78613
spec-tp arg in benchmark_latency
wooyeonlee0 Jun 18, 2024
6532af7
yapf
wooyeonlee0 Jun 18, 2024
6839797
yapf
wooyeonlee0 Jun 18, 2024
aac586b
Merge remote-tracking branch 'origin' into spec-tp1-draft
wooyeonlee0 Jun 19, 2024
98e584d
remove is_dummy check from sampler_output
wooyeonlee0 Jun 19, 2024
2d5e64d
add comment
wooyeonlee0 Jun 20, 2024
ba88bd4
yapf
wooyeonlee0 Jun 20, 2024
46e5274
resolve cade comments
wooyeonlee0 Jun 21, 2024
85f4f25
refactoring patch_tp_group
wooyeonlee0 Jun 21, 2024
c1b5373
cleanup patch_tp_group logic
wooyeonlee0 Jun 21, 2024
4a58617
speculative_draft_tensor_parallel_size
wooyeonlee0 Jun 21, 2024
b09e7be
ruff, yapf
wooyeonlee0 Jun 21, 2024
7168d78
remove world group patch
wooyeonlee0 Jun 21, 2024
fe0bd5b
isort, yapf
wooyeonlee0 Jun 21, 2024
2e0d170
yield fix
wooyeonlee0 Jun 21, 2024
36f8aa5
debugging
wooyeonlee0 Jun 21, 2024
54bf514
log
wooyeonlee0 Jun 21, 2024
bfd7d2f
reintroduce smaller_tp_proposer_worker
wooyeonlee0 Jun 21, 2024
f337428
add lora methods
wooyeonlee0 Jun 21, 2024
4654b9f
missing method
wooyeonlee0 Jun 21, 2024
e39926e
remove world group related logics
wooyeonlee0 Jun 21, 2024
1c6eefd
Always wrapping MultiStepWorker
wooyeonlee0 Jun 21, 2024
f2d2ee5
remove unused logger
wooyeonlee0 Jun 21, 2024
302955c
isort. minor rename
wooyeonlee0 Jun 21, 2024
3d4754e
LoraNotSupported. return type
wooyeonlee0 Jun 21, 2024
620b224
yapf, ruff
wooyeonlee0 Jun 21, 2024
b245d3c
add skip_spec_test
wooyeonlee0 Jun 21, 2024
1e71e98
remove spec-tp 3 case
wooyeonlee0 Jun 21, 2024
a01c00d
spec-draft-tp
wooyeonlee0 Jun 21, 2024
debffc2
_TP_STATE_PATCHED
wooyeonlee0 Jun 24, 2024
39fe67f
remove stale comment
wooyeonlee0 Jun 24, 2024
af1b0be
dist_tp2, dist_tp4 tests
wooyeonlee0 Jun 24, 2024
834c6e0
remove unnecessary overriding methods
wooyeonlee0 Jun 24, 2024
5bc2bc3
comment
wooyeonlee0 Jun 24, 2024
8740369
yapf
wooyeonlee0 Jun 24, 2024
4d82ca1
comment
wooyeonlee0 Jun 24, 2024
7bf831c
undo change in test utils
wooyeonlee0 Jun 24, 2024
3fccc76
remove test_skip_speculation
wooyeonlee0 Jun 24, 2024
e8d0e93
tp4 test only for spec_tp1
wooyeonlee0 Jun 25, 2024
91c2e43
allow only value 1 for spec_tp
wooyeonlee0 Jun 25, 2024
fac7e68
yapf
wooyeonlee0 Jun 25, 2024
271822e
add todo comment
wooyeonlee0 Jun 25, 2024
ae0d7f1
add tests for check that test_skip fails even there's no spec_draft_t…
wooyeonlee0 Jun 25, 2024
b84a070
remove test_skip_speculation from dist tests
wooyeonlee0 Jun 25, 2024
86fda24
yapf
wooyeonlee0 Jun 25, 2024
27 changes: 19 additions & 8 deletions vllm/config.py
@@ -783,6 +783,7 @@ def maybe_create_spec_config(
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
@@ -878,7 +879,6 @@ def maybe_create_spec_config(
# config, in future, we may try refactor it out, and set
# draft related config as None here.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
else:
ngram_prompt_lookup_max = 0
ngram_prompt_lookup_min = 0
@@ -907,9 +907,9 @@
target_model_config.max_model_len,
))

draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config))
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config, speculative_tensor_parallel_size))

return SpeculativeConfig(
draft_model_config,
@@ -957,16 +957,27 @@ def _maybe_override_draft_max_model_len(

@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig) -> ParallelConfig:
target_parallel_config: ParallelConfig,
speculative_tensor_parallel_size: Optional[int]) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.

This is mostly a copy of the target parallel config. In the future the
draft worker can have a different parallel strategy, e.g. TP=1.
This is mostly a copy of the target parallel config, except the tp_size.
"""

        speculative_tensor_parallel_size = (
            speculative_tensor_parallel_size
            or target_parallel_config.tensor_parallel_size)

Review thread on the lines above:

Collaborator:
nit: if the user provides --speculative-tensor-parallel-size 0, this branch causes unexpected behavior. Can we explicitly guard against this?

wooyeonlee0 (Contributor Author), Jun 21, 2024:
Thanks for the catch! To prevent spec_tp from being set to target_tp when the given spec_tp is 0, I've changed the code as below:

        if speculative_tensor_parallel_size is None:
            speculative_tensor_parallel_size = target_parallel_config.tensor_parallel_size

In addition, to prevent the tp value from being 0, I think we need a separate PR that adds a check in ParallelConfig._verify_args(), since the same issue exists for --tensor-parallel-size 0. What do you think?

Collaborator:
I'll look at your new changes; no need to completely fix this (just a nit).

if speculative_tensor_parallel_size > \
target_parallel_config.tensor_parallel_size:
raise ValueError(
f"{speculative_tensor_parallel_size=} cannot be "
f"larger than {target_parallel_config.tensor_parallel_size}")

draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
tensor_parallel_size=speculative_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
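A small sketch of the default-handling behavior the reviewer flagged above (plain Python, nothing from vLLM): the `or` form silently maps an explicit 0 to the target TP size, while an explicit `is None` check keeps the 0 so that validation can reject it.

def draft_tp_with_or(spec_tp, target_tp):
    # `or` treats 0 as "not provided" because 0 is falsy
    return spec_tp or target_tp

def draft_tp_with_none_check(spec_tp, target_tp):
    # only substitute the default when the user passed nothing at all
    return target_tp if spec_tp is None else spec_tp

print(draft_tp_with_or(0, 4))            # 4 -- explicit 0 silently ignored
print(draft_tp_with_none_check(0, 4))    # 0 -- preserved, can be rejected later
print(draft_tp_with_or(None, 4))         # 4
print(draft_tp_with_none_check(None, 4)) # 4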
31 changes: 30 additions & 1 deletion vllm/distributed/parallel_state.py
@@ -550,6 +550,10 @@ def init_distributed_environment(
global _WORLD
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
if world_size != -1:
assert world_size == len(ranks), (
"given world_size does not match with world_size of torch")

_WORLD = GroupCoordinator(
group_ranks=[ranks],
local_rank=local_rank,
@@ -558,7 +562,7 @@
use_custom_allreduce=False,
)
else:
assert _WORLD.world_size == torch.distributed.get_world_size(), (
assert _WORLD.world_size == world_size, (
"world group already initialized with a different world size")


@@ -673,6 +677,31 @@ def model_parallel_is_initialized():
return (_TP is not None and _PP is not None)


OVERRIDE_TP_STATE = False


@contextlib.contextmanager
def patch_tensor_parallel_group(world_group, tp_group):
Review thread on this function:

Contributor:
Will this global variable patching potentially create problems? For example, is it possible that other workers will use this context unknowingly?

wooyeonlee0 (Contributor Author), Jun 14, 2024:
In the current design of speculative decoding, the draft and target workers execute sequentially, so there is no chance of the target workers using the patched/overridden context. But if the draft and target workers ever execute concurrently, the code should be redesigned to keep their states from being mixed.
"""Patch the tp group temporarily until this function ends."""
global OVERRIDE_TP_STATE
if OVERRIDE_TP_STATE:
return

OVERRIDE_TP_STATE = True
old_world_group = get_world_group()
old_tp_group = get_tp_group()
global _WORLD, _TP
_WORLD = world_group
_TP = tp_group
try:
yield
finally:
# restore the original state
OVERRIDE_TP_STATE = False
_WORLD = old_world_group
_TP = old_tp_group


def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
return get_tp_group().world_size
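As a stand-alone illustration of the pattern used by patch_tensor_parallel_group above, here is a rough sketch (names are illustrative, not the vLLM API): a module-level handle is swapped for the duration of a with block and restored in finally, with a flag guarding against nested re-patching.

import contextlib

_GROUP = "target_tp_group"   # stands in for the module-global _TP handle
_PATCHED = False             # stands in for the patched-state flag

@contextlib.contextmanager
def patch_group(new_group):
    """Temporarily swap the module-level group and restore it on exit."""
    global _GROUP, _PATCHED
    if _PATCHED:
        # already patched by an outer scope; keep the current state
        yield
        return
    _PATCHED = True
    old_group = _GROUP
    _GROUP = new_group
    try:
        yield
    finally:
        _GROUP = old_group
        _PATCHED = False

with patch_group("draft_tp_group"):
    print(_GROUP)   # draft_tp_group -- draft model forward passes would run here
print(_GROUP)       # target_tp_group -- restored for the target model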
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -93,6 +93,7 @@ class EngineArgs:
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
@@ -534,6 +535,13 @@
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-tensor-parallel-size',
'-spec-tp',
type=int,
default=EngineArgs.speculative_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')

parser.add_argument(
'--speculative-max-model-len',
@@ -676,6 +684,8 @@ def create_engine_config(self, ) -> EngineConfig:
target_parallel_config=parallel_config,
target_dtype=self.dtype,
speculative_model=self.speculative_model,
speculative_tensor_parallel_size = \
self.speculative_tensor_parallel_size,
num_speculative_tokens=self.num_speculative_tokens,
speculative_disable_by_batch_size=self.
speculative_disable_by_batch_size,
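A rough sketch of how the new flag flows through the CLI path above, using only argparse (flag spelling as it appears in this diff; later commits in the PR rename it). Model names are placeholders and nothing is actually loaded here:

import argparse

from vllm.engine.arg_utils import EngineArgs

# Parse the flag the same way a server entrypoint would.
parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
ns = parser.parse_args([
    "--model", "meta-llama/Llama-2-13b-chat-hf",
    "--tensor-parallel-size", "4",
    "--speculative-model", "JackFram/llama-68m",
    "--speculative-tensor-parallel-size", "1",
    "--num-speculative-tokens", "5",
])
engine_args = EngineArgs.from_cli_args(ns)
print(engine_args.speculative_tensor_parallel_size)  # 1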
9 changes: 5 additions & 4 deletions vllm/spec_decode/multi_step_worker.py
@@ -6,7 +6,8 @@

from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer)
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
@@ -28,7 +29,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Lazy initialization list.
self._proposer: Top1Proposer
self._proposer: SpeculativeProposer

def init_device(self):
super().init_device()
@@ -71,9 +72,9 @@ def sampler_output(
sample_len)

# Run model sample_len times.
model_outputs = []
model_outputs: List[SamplerOutput] = []
for _ in range(sample_len):
model_output = super().execute_model(
model_output: List[SamplerOutput] = super().execute_model(
execute_model_req=copied_execute_model_req)
assert (len(model_output) == 1
), "composing multistep workers not supported"
213 changes: 213 additions & 0 deletions vllm/spec_decode/smaller_tp_proposer_worker.py
@@ -0,0 +1,213 @@
from typing import List, Optional, Set, Tuple, Union

import torch
import torch.distributed

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (_ENABLE_CUSTOM_ALL_REDUCE,
GroupCoordinator, get_tp_group,
get_world_group,
patch_tensor_parallel_group)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
from vllm.worker.worker import Worker

logger = init_logger(__name__)


class SmallerTpProposerWorker(ProposerWorkerBase):
"""Class which allows a speculative draft model to run with smaller tensor
parallel degree than target model.
This reduces the communication overhead of small draft models.

    This is implemented by temporarily switching vLLM's tensor parallel group
    to a smaller group during forward passes of draft models.
"""

@classmethod
def maybe_wrap_worker(cls, worker, draft_parallel_config: ParallelConfig,
target_parallel_config: ParallelConfig, rank: int):
"""Wrap the worker in a SmallerTpProposerWorker if necessary.
"""
draft_tp = draft_parallel_config.tensor_parallel_size
target_tp = target_parallel_config.tensor_parallel_size

if draft_tp == target_tp:
return worker

if draft_tp > target_tp:
raise ValueError(
f"{cls} only supports draft_tp smaller than target_tp."
f"{draft_tp=} {target_tp=}")

# gpu ranks that will generate draft tokens together
ranks = list(range(draft_tp))

if rank in ranks:
logger.info("Wrapping {%s} in {%s}", type(worker), cls)
return cls(worker, ranks)
else:
# for workers not participating in the draft generation
logger.info("Returning dummy worker")
return DummyProposerWorker(worker)

def __init__(self, worker: Union[Worker, ProposerWorkerBase],
ranks: List[int]):
self._worker = worker
self._ranks = ranks
self._world_group = None
self._tp_group = None

def _patch_tensor_parallel_group(self):
return patch_tensor_parallel_group(self._world_group, self._tp_group)

def init_device(self):
"""Initialize the device.

This also creates an additional tensor-parallel process group containing
only a subset of the whole ranks.
"""
local_rank = get_world_group().local_rank
world_backend = torch.distributed.get_backend(
get_world_group().device_group)
tp_backend = torch.distributed.get_backend(get_tp_group().device_group)

self._world_group = GroupCoordinator(
group_ranks=[self._ranks],
local_rank=local_rank,
torch_distributed_backend=world_backend,
use_pynccl=False,
use_custom_allreduce=False,
)
self._tp_group = GroupCoordinator(
group_ranks=[self._ranks],
local_rank=local_rank,
torch_distributed_backend=tp_backend,
use_pynccl=True,
use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
)

with self._patch_tensor_parallel_group():
self._worker.init_device()

def set_include_gpu_probs_tensor(self):
self._worker.set_include_gpu_probs_tensor()

def load_model(self):
with self._patch_tensor_parallel_group():
self._worker.load_model()

def determine_num_available_blocks(self):
with self._patch_tensor_parallel_group():
return self._worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int):
with self._patch_tensor_parallel_group():
self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]:
        # do not call _patch_tensor_parallel_group, because
# it's always called after tp_group has already been overridden
return self._worker.sampler_output(execute_model_req, sample_len)

def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
with self._patch_tensor_parallel_group():
return self._worker.get_spec_proposals(execute_model_req)

@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
with self._patch_tensor_parallel_group():
return self._worker.execute_model(execute_model_req)

def get_cache_block_size_bytes(self) -> int:
return self._worker.get_cache_block_size_bytes()

def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError

def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

def list_loras(self) -> Set[int]:
raise NotImplementedError

@property
def max_model_len(self) -> int:
return self._worker.max_model_len

@property
def vocab_size(self) -> int:
return self._worker.vocab_size


class DummyProposerWorker(ProposerWorkerBase):
"""Dummy proposer worker that do nothing.
It's for workers that do not participate in draft generation.
"""

def __init__(
self,
worker: Union[Worker, ProposerWorkerBase],
):
self._worker = worker

def init_device(self):
pass

def load_model(self):
pass

def determine_num_available_blocks(self):
pass

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int):
pass

def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
return [], True

def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
return SpeculativeProposals(None, None, None)

def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
return []

def get_cache_block_size_bytes(self) -> int:
return 0

def add_lora(self, lora_request: LoRARequest) -> bool:
return False

def remove_lora(self, lora_id: int) -> bool:
return False

def list_loras(self) -> Set[int]:
return set()

@property
def vocab_size(self) -> int:
return self._worker.vocab_size
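To make the wrapping rule in maybe_wrap_worker concrete, a tiny stand-alone sketch of the dispatch decision (strings stand in for the real worker classes; this is not vLLM code):

from typing import List

def pick_proposer(rank: int, draft_tp: int, target_tp: int) -> str:
    """Mirror of the maybe_wrap_worker dispatch, with strings as stand-ins."""
    if draft_tp == target_tp:
        return "plain draft worker (no wrapping needed)"
    if draft_tp > target_tp:
        raise ValueError(f"draft_tp={draft_tp} must not exceed target_tp={target_tp}")
    draft_ranks: List[int] = list(range(draft_tp))  # e.g. [0] when draft_tp == 1
    return ("SmallerTpProposerWorker" if rank in draft_ranks
            else "DummyProposerWorker")

# With target TP=4 and draft TP=1, only rank 0 runs the draft model; ranks
# 1-3 receive dummy proposers that return empty outputs during proposal.
for r in range(4):
    print(r, pick_proposer(r, draft_tp=1, target_tp=4))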