Skip to content

Commit

Permalink
minor refine
Browse files Browse the repository at this point in the history
  • Loading branch information
wooyeonlee0 committed Jun 12, 2024
1 parent a96e720 commit db39576
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,14 +941,17 @@ def _maybe_override_draft_max_model_len(
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_tensor_parallel_size: int) -> ParallelConfig:
speculative_tensor_parallel_size: Optional[int]) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config.
"""

_speculative_tensor_parallel_size = speculative_tensor_parallel_size or target_parallel_config.tensor_parallel_size

draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
tensor_parallel_size=target_parallel_config.tensor_parallel_size,
tensor_parallel_size=_speculative_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.
distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.
Expand All @@ -961,9 +964,6 @@ def create_draft_parallel_config(
placement_group=target_parallel_config.placement_group,
)

if speculative_tensor_parallel_size is not None:
draft_parallel_config.tensor_parallel_size = speculative_tensor_parallel_size

return draft_parallel_config

def __init__(
Expand Down

0 comments on commit db39576

Please sign in to comment.