Skip to content

Commit

Permalink
[Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size …
Browse files Browse the repository at this point in the history
…to 1 when using MLPSpeculator (vllm-project#7105)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
  • Loading branch information
tdoublep authored and sfc-gh-mkeralapura committed Aug 12, 2024
1 parent 7c5bc0a commit 3022472
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ def maybe_create_spec_config(
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config,
speculative_draft_tensor_parallel_size))
speculative_draft_tensor_parallel_size, draft_hf_config))

if num_speculative_tokens is None:
raise ValueError(
Expand Down Expand Up @@ -1136,15 +1136,23 @@ def _maybe_override_draft_max_model_len(
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int]
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
Expand Down

0 comments on commit 3022472

Please sign in to comment.