Commit b6f99a6

[Core] Refactor executor classes to make it easier to inherit GPUExecutor (vllm-project#7673)
jikunshang authored Aug 20, 2024
1 parent ad28a74 commit b6f99a6
Showing 2 changed files with 27 additions and 21 deletions.
vllm/executor/gpu_executor.py (16 additions, 11 deletions)

@@ -62,6 +62,18 @@ def _get_worker_kwargs(
             observability_config=self.observability_config,
         )
 
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.scheduler_config.is_multi_step:
+            worker_module_name = "vllm.worker.multi_step_worker"
+            worker_class_name = "MultiStepWorker"
+        elif self.speculative_config:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+        else:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+        return (worker_module_name, worker_class_name)
+
     def _get_create_worker_kwargs(
         self,
         local_rank: int = 0,

@@ -70,17 +82,10 @@ def _get_create_worker_kwargs(
         worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                 distributed_init_method)
 
-        if self.scheduler_config.is_multi_step:
-            worker_kwargs.update(
-                worker_module_name="vllm.worker.multi_step_worker",
-                worker_class_name="MultiStepWorker")
-        elif self.speculative_config:
-            worker_kwargs.update(
-                worker_module_name="vllm.spec_decode.spec_decode_worker",
-                worker_class_name="create_spec_worker")
-        else:
-            worker_kwargs.update(worker_module_name="vllm.worker.worker",
-                                 worker_class_name="Worker")
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
+        worker_kwargs.update(worker_module_name=worker_module_name,
+                             worker_class_name=worker_class_name)
 
         return worker_kwargs
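The extracted _get_worker_module_and_class hook turns worker selection into a single override point, which is the point of this refactor. A minimal sketch of the intended inheritance pattern (the CustomGPUExecutor class and the my_pkg.custom_worker module are hypothetical, not part of this commit):

from typing import Tuple

from vllm.executor.gpu_executor import GPUExecutor


class CustomGPUExecutor(GPUExecutor):
    # Hypothetical subclass: swap in a custom Worker implementation
    # without duplicating the multi-step/speculative selection logic.
    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        if self.scheduler_config.is_multi_step or self.speculative_config:
            # Defer to the default selection for these modes.
            return super()._get_worker_module_and_class()
        # Point the executor at our own worker implementation
        # (hypothetical module and class names).
        return ("my_pkg.custom_worker", "CustomWorker")
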
vllm/executor/ray_gpu_executor.py (11 additions, 10 deletions)

@@ -91,22 +91,19 @@ def _configure_ray_workers_use_nsight(self,
         return ray_remote_kwargs
 
     def _get_worker_wrapper_args(self) -> Dict[str, Any]:
-        if self.speculative_config is not None:
-            worker_module_name = "vllm.spec_decode.spec_decode_worker"
-            worker_class_name = "create_spec_worker"
-        elif self.scheduler_config.is_multi_step:
-            worker_module_name = "vllm.worker.multi_step_worker"
-            worker_class_name = "MultiStepWorker"
-        else:
-            worker_module_name = "vllm.worker.worker"
-            worker_class_name = "Worker"
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
 
         return dict(
             worker_module_name=worker_module_name,
             worker_class_name=worker_class_name,
             trust_remote_code=self.model_config.trust_remote_code,
         )
 
+    # child class could overwrite this to return actual env vars.
+    def _get_env_vars_to_be_updated(self):
+        return self._env_vars_for_all_workers
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1

@@ -231,8 +228,12 @@ def sort_by_driver_then_worker_ip(worker):
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
+
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
+
         self._run_workers("update_environment_variables",
-                          all_args=all_args_to_update_environment_variables)
+                          all_args=self._get_env_vars_to_be_updated())
 
         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
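On the Ray path, _init_workers_ray now stores the per-worker environment in self._env_vars_for_all_workers and routes it through the new _get_env_vars_to_be_updated hook, so a subclass can adjust what gets pushed to each worker. A sketch under the assumption that the stored value is the list of one-element tuples built above (the subclass name and the extra variable are illustrative only, not part of this commit):

from vllm.executor.ray_gpu_executor import RayGPUExecutor


class CustomRayGPUExecutor(RayGPUExecutor):
    # Hypothetical subclass: extend the environment sent to each worker.
    def _get_env_vars_to_be_updated(self):
        env_vars = self._env_vars_for_all_workers
        # env_vars holds one ({"NAME": "value", ...},) tuple per worker,
        # as built in _init_workers_ray above.
        for (worker_env, ) in env_vars:
            worker_env["MY_CUSTOM_FLAG"] = "1"  # illustrative only
        return env_vars
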
