File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -941,14 +941,17 @@ def _maybe_override_draft_max_model_len(
941
941
@staticmethod
942
942
def create_draft_parallel_config (
943
943
target_parallel_config : ParallelConfig ,
944
- speculative_tensor_parallel_size : int ) -> ParallelConfig :
944
+ speculative_tensor_parallel_size : Optional [ int ] ) -> ParallelConfig :
945
945
"""Create a parallel config for use by the draft worker.
946
946
947
947
This is mostly a copy of the target parallel config.
948
948
"""
949
+
950
+ _speculative_tensor_parallel_size = speculative_tensor_parallel_size or target_parallel_config .tensor_parallel_size
951
+
949
952
draft_parallel_config = ParallelConfig (
950
953
pipeline_parallel_size = target_parallel_config .pipeline_parallel_size ,
951
- tensor_parallel_size = target_parallel_config . tensor_parallel_size ,
954
+ tensor_parallel_size = _speculative_tensor_parallel_size ,
952
955
distributed_executor_backend = target_parallel_config .
953
956
distributed_executor_backend ,
954
957
max_parallel_loading_workers = target_parallel_config .
@@ -961,9 +964,6 @@ def create_draft_parallel_config(
961
964
placement_group = target_parallel_config .placement_group ,
962
965
)
963
966
964
- if speculative_tensor_parallel_size is not None :
965
- draft_parallel_config .tensor_parallel_size = speculative_tensor_parallel_size
966
-
967
967
return draft_parallel_config
968
968
969
969
def __init__ (
You can’t perform that action at this time.
0 commit comments