15 changes: 14 additions & 1 deletion vllm/engine/arg_utils.py
@@ -26,7 +26,7 @@
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor

if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
@@ -1240,6 +1240,18 @@ def create_engine_config(
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
)

# If Ray is initialized and we are running inside a Ray actor, capture
# the current placement group so that it can be passed on to spawned
# worker processes.
placement_group = None
if is_in_ray_actor():
import ray

# Note: ray.util.get_current_placement_group() would initialize Ray
# automatically if it were not already initialized, which we do not want
# here; the is_in_ray_actor() check above guarantees Ray is already
# running before we make this call.
placement_group = ray.util.get_current_placement_group()

parallel_config = ParallelConfig(
pipeline_parallel_size=self.pipeline_parallel_size,
tensor_parallel_size=self.tensor_parallel_size,
@@ -1252,6 +1264,7 @@
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
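A minimal, self-contained sketch of the scenario this hunk targets (the `EngineActor` wrapper and the CPU-only bundles are illustrative assumptions, not part of this change): when the engine is constructed inside a Ray actor that was scheduled into a placement group, `ray.util.get_current_placement_group()` returns that group, and `create_engine_config()` can then forward it to spawned worker processes.

```python
import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote(num_cpus=1)
class EngineActor:
    """Hypothetical wrapper that would build a vLLM engine internally."""

    def current_bundles(self):
        # Inside the actor, Ray is initialized and an actor ID exists, so
        # the same call used in create_engine_config() returns the group
        # this actor was scheduled into.
        pg = ray.util.get_current_placement_group()
        return pg.bundle_specs if pg else None


if __name__ == "__main__":
    ray.init()
    pg = placement_group([{"CPU": 1}] * 2, strategy="PACK")
    ray.get(pg.ready())

    actor = EngineActor.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg)).remote()
    print(ray.get(actor.current_bundles.remote()))  # e.g. [{'CPU': 1.0}, {'CPU': 1.0}]
```

With this change, the group captured that way is stored on ParallelConfig (the `placement_group=placement_group` line above) and preferred by `initialize_ray_cluster()` in vllm/executor/ray_utils.py.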
4 changes: 2 additions & 2 deletions vllm/executor/multiproc_worker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import _check_multiproc_method, get_mp_context, run_method
from vllm.utils import _maybe_force_spawn, get_mp_context, run_method

logger = init_logger(__name__)

@@ -291,7 +291,7 @@ def set_multiprocessing_worker_envs(parallel_config):
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""

_check_multiproc_method()
_maybe_force_spawn()

# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
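For context on what `_maybe_force_spawn()` guards against, here is a small standalone sketch of the start-method selection (the `pick_mp_context` helper below only mirrors the behavior described in the diff; it is not vLLM's implementation): the method comes from VLLM_WORKER_MULTIPROC_METHOD, and forcing spawn means each worker starts from a fresh interpreter instead of inheriting parent state such as an initialized CUDA context or Ray actor state.

```python
import multiprocessing
import os


def pick_mp_context():
    # Follow the env var, defaulting to fork; vLLM's _maybe_force_spawn()
    # overwrites this with "spawn" when forking would be unsafe.
    method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork")
    return multiprocessing.get_context(method)


def worker(rank: int) -> None:
    print(f"worker {rank} running in pid {os.getpid()}")


if __name__ == "__main__":
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    ctx = pick_mp_context()
    procs = [ctx.Process(target=worker, args=(rank, )) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```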
21 changes: 13 additions & 8 deletions vllm/executor/ray_utils.py
@@ -284,8 +284,9 @@ def initialize_ray_cluster(
assert_ray_available()
from vllm.platforms import current_platform

# Connect to a ray cluster.
if current_platform.is_rocm() or current_platform.is_xpu():
if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu():
# Try to connect to an existing Ray instance; create a new one if not found
try:
ray.init("auto", ignore_reinit_error=True)
@@ -299,19 +300,21 @@
else:
ray.init(address=ray_address, ignore_reinit_error=True)

if parallel_config.placement_group:
# Placement group is already set.
return

device_str = current_platform.ray_device_key
if not device_str:
raise ValueError(
f"current platform {current_platform.device_name} does not "
"support ray.")

# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
# Create or get the placement group for worker processes
if parallel_config.placement_group:
current_placement_group = parallel_config.placement_group
else:
current_placement_group = ray.util.get_current_placement_group()

if current_placement_group:
logger.info("Using the existing placement group")

# We are in a placement group
bundles = current_placement_group.bundle_specs
# Verify that we can use the placement group.
@@ -331,6 +334,8 @@
f"Required number of devices: {parallel_config.world_size}. "
f"Total number of devices: {device_bundles}.")
else:
logger.info("No current placement group found. "
"Creating a new placement group.")
num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
# Log a warning message and delay resource allocation failure response.
# Avoid immediate rejection to allow user-initiated placement group
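The device check truncated above can be pictured with a short standalone sketch (an assumed shape of the verification, not the exact vLLM code): bundles that reserve the platform's device are counted, and the placement group is rejected if it cannot cover the configured world size.

```python
from typing import Mapping, Sequence


def verify_bundles(bundle_specs: Sequence[Mapping[str, float]],
                   device_str: str, world_size: int) -> None:
    # Count bundles that reserve at least one of the platform's devices
    # (e.g. "GPU"); the placement group must provide one per worker.
    device_bundles = sum(1 for bundle in bundle_specs
                         if bundle.get(device_str, 0) > 0)
    if world_size > device_bundles:
        raise ValueError(
            "Placement group does not have enough devices. "
            f"Required number of devices: {world_size}. "
            f"Total number of devices: {device_bundles}.")


# Two single-GPU bundles satisfy a world size of 2.
verify_bundles([{"GPU": 1, "CPU": 1}, {"GPU": 1, "CPU": 1}], "GPU", 2)
```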
48 changes: 38 additions & 10 deletions vllm/utils.py
@@ -2142,20 +2142,48 @@ def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]:
ctx.destroy(linger=0)


def _check_multiproc_method():
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"troubleshooting.html#python-multiprocessing "
"for more information.")
def is_in_ray_actor():
"""Check if we are in a Ray actor."""

try:
import ray
return (ray.is_initialized()
and ray.get_runtime_context().get_actor_id() is not None)
except ImportError:
return False


def _maybe_force_spawn():
"""Check if we need to force the use of the `spawn` multiprocessing start
method.
"""
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") == "spawn":
return

reason = None
if cuda_is_initialized():
reason = "CUDA is initialized"
elif is_in_ray_actor():
reason = "In a Ray actor and can only be spawned"

if reason is not None:
logger.warning(
"We must use the `spawn` multiprocessing start method. "
"Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
"See https://docs.vllm.ai/en/latest/getting_started/"
"troubleshooting.html#python-multiprocessing "
"for more information. Reason: %s", reason)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def get_mp_context():
_check_multiproc_method()
"""Get a multiprocessing context with a particular method (spawn or fork).
By default we follow the value of VLLM_WORKER_MULTIPROC_METHOD to
determine the start method (default is fork). However, under certain
conditions we force spawn and override VLLM_WORKER_MULTIPROC_METHOD.
"""
_maybe_force_spawn()
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)

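A quick usage sketch of the two helpers added in this file (assuming vLLM with this change installed, so both names are importable from vllm.utils): `is_in_ray_actor()` is False on the Ray driver and True inside an actor, and in the latter case `get_mp_context()` forces the spawn start method.

```python
import ray
from vllm.utils import get_mp_context, is_in_ray_actor


@ray.remote
class Probe:

    def check(self):
        # Inside the actor, _maybe_force_spawn() kicks in, so the returned
        # context uses the spawn start method.
        return is_in_ray_actor(), get_mp_context().get_start_method()


if __name__ == "__main__":
    ray.init()
    print(is_in_ray_actor())              # False: driver process, not an actor
    probe = Probe.remote()
    print(ray.get(probe.check.remote()))  # expected: (True, 'spawn')
```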