17 changes: 14 additions & 3 deletions examples/offline_inference/rlhf.py
@@ -27,11 +27,17 @@
 
 class MyLLM(LLM):
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, pg_name, ray_namespace, **kwargs):
         # a hack to make the script work.
         # stop ray from manipulating CUDA_VISIBLE_DEVICES
         # at the top-level
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+        # set the placement group name for vLLM to use
+        os.environ['VLLM_RAY_PG_NAME'] = pg_name
+        # set the ray namespace for vLLM to use
+        os.environ['VLLM_RAY_NAMESPACE'] = ray_namespace
+        # set the ray address for vLLM to use
+        os.environ['RAY_ADDRESS'] = ray.get_runtime_context().gcs_address
         super().__init__(*args, **kwargs)
 
 
@@ -47,10 +53,13 @@
 GPU 2. For the details on how to use ray, please refer to the ray
 documentation https://docs.ray.io/en/latest/ .
 """
+PG_NAME = "pg_inference"
+RAY_NAMESPACE = "rlhf"
+
 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
-ray.init()
+ray.init(namespace=RAY_NAMESPACE)
 
-pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
+pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2, name=PG_NAME)
 ray.get(pg_inference.ready())
 scheduling_inference = PlacementGroupSchedulingStrategy(
     placement_group=pg_inference,
@@ -71,6 +80,8 @@
     worker_extension_cls="rlhf_utils.WorkerExtension",
     tensor_parallel_size=2,
     distributed_executor_backend="ray",
+    pg_name=PG_NAME,
+    ray_namespace=RAY_NAMESPACE,
 )
 
 # Generate texts from the prompts.
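The example's driver (in code elided between the hunks above) constructs MyLLM as a Ray actor. A minimal sketch of how the new keyword arguments might be passed through; the actor options and model name are assumptions, not lines from this diff:

# Sketch only: passing the new keyword-only arguments to MyLLM. The model name
# and ray.remote actor options here are placeholders, not taken from this diff.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model="facebook/opt-125m",  # placeholder model
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
    pg_name=PG_NAME,
    ray_namespace=RAY_NAMESPACE,
)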
9 changes: 9 additions & 0 deletions vllm/envs.py
@@ -591,6 +591,15 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     "VLLM_RAY_BUNDLE_INDICES":
     lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),
 
+    # Name of an existing Ray placement group; if set, vLLM uses that
+    # placement group instead of creating its own.
+    "VLLM_RAY_PG_NAME":
+    lambda: os.getenv("VLLM_RAY_PG_NAME", None),
+
+    # Ray namespace; if set, vLLM connects to Ray using this namespace.
+    "VLLM_RAY_NAMESPACE":
+    lambda: os.getenv("VLLM_RAY_NAMESPACE", None),
+
     # When on a Nvidia GPU aligns single entries (within a page) so they are 256
     # byte aligned for better performance, this increases the memory usage of
     # the cache. Currently this only affects MLA that results in non-256
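Because the vllm.envs entries are lambdas, both variables are read at engine start-up rather than at import time, so they can also be set by an external launcher instead of an LLM subclass. A rough sketch of that route, with illustrative names that are not part of this diff:

# Sketch: attaching a plain LLM to a pre-created, named placement group via the
# new environment variables. The namespace, group name, and model are illustrative.
import os
import ray
from ray.util.placement_group import placement_group
from vllm import LLM

ray.init(namespace="rlhf")
pg = placement_group([{"GPU": 1, "CPU": 0}] * 2, name="pg_inference")
ray.get(pg.ready())

os.environ["VLLM_RAY_PG_NAME"] = "pg_inference"
# VLLM_RAY_NAMESPACE only matters when vLLM initializes Ray itself; it is
# harmless here because Ray is already initialized in this process.
os.environ["VLLM_RAY_NAMESPACE"] = "rlhf"

llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="ray")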
28 changes: 22 additions & 6 deletions vllm/executor/ray_utils.py
@@ -8,6 +8,7 @@
 import msgspec
 
 import vllm.platforms
+from vllm import envs
 from vllm.config import ParallelConfig
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
@@ -283,12 +284,14 @@ def initialize_ray_cluster(
     """
     assert_ray_available()
     from vllm.platforms import current_platform
-
-    # Connect to a ray cluster.
-    if current_platform.is_rocm() or current_platform.is_xpu():
+    if ray.is_initialized():
+        logger.info("Ray is already initialized. Skipping Ray initialization.")
+    elif current_platform.is_rocm() or current_platform.is_xpu():
         # Try to connect existing ray instance and create a new one if not found
         try:
-            ray.init("auto", ignore_reinit_error=True)
+            ray.init("auto",
+                     ignore_reinit_error=True,
+                     namespace=envs.VLLM_RAY_NAMESPACE)
         except ConnectionError:
             logger.warning(
                 "No existing RAY instance detected. "
@@ -297,7 +300,9 @@ def initialize_ray_cluster(
                      ignore_reinit_error=True,
                      num_gpus=parallel_config.world_size)
     else:
-        ray.init(address=ray_address, ignore_reinit_error=True)
+        ray.init(address=ray_address,
+                 ignore_reinit_error=True,
+                 namespace=envs.VLLM_RAY_NAMESPACE)
 
     if parallel_config.placement_group:
         # Placement group is already set.
@@ -310,8 +315,17 @@ def initialize_ray_cluster(
                 "support ray.")
 
     # Create placement group for worker processes
-    current_placement_group = ray.util.get_current_placement_group()
+    if envs.VLLM_RAY_PG_NAME:
+        # the placement group is specified by the user
+        logger.info(
+            "Looking for the placement group specified by"
+            " VLLM_RAY_PG_NAME: %s", envs.VLLM_RAY_PG_NAME)
+        current_placement_group = ray.util.get_placement_group(
+            envs.VLLM_RAY_PG_NAME)
+    else:
+        current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
+        logger.info("Using current placement group")
         # We are in a placement group
         bundles = current_placement_group.bundle_specs
         # Verify that we can use the placement group.
@@ -331,6 +345,8 @@ def initialize_ray_cluster(
                 f"Required number of devices: {parallel_config.world_size}. "
                 f"Total number of devices: {device_bundles}.")
     else:
+        logger.info("No current placement group found. "
+                    "Creating a new placement group.")
         num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
         # Log a warning message and delay resource allocation failure response.
         # Avoid immediate rejection to allow user-initiated placement group
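The lookup above relies on ray.util.get_placement_group, which retrieves a named placement group created by another process in the same Ray cluster and namespace; this is what lets a vLLM engine running inside a Ray actor reuse the bundles the training driver reserved. A standalone sketch with illustrative names:

# Sketch: fetching a named placement group created elsewhere. "rlhf" and
# "pg_inference" are illustrative names, not part of this diff.
import ray
from ray.util import get_placement_group

ray.init(address="auto", namespace="rlhf")
pg = get_placement_group("pg_inference")  # raises if no such group exists
print(pg.bundle_specs)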
11 changes: 11 additions & 0 deletions vllm/utils.py
@@ -2174,6 +2174,17 @@ def _check_multiproc_method():
             "troubleshooting.html#python-multiprocessing "
             "for more information.")
         os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+    try:
+        import ray
+        if (ray.is_initialized()
+                and ray.get_runtime_context().get_actor_id() is not None):
+            logger.info(
+                "vLLM is running as a Ray actor. "
+                "Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn', "
+                "because Ray processes can only be spawned, not forked.")
+            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+    except ImportError:
+        pass
 
 
 def get_mp_context():
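The detection above hinges on ray.get_runtime_context().get_actor_id() returning None outside an actor and a real id inside one. A tiny sketch of that behavior; Probe is a hypothetical actor, not part of vLLM:

# Sketch: get_actor_id() distinguishes a driver process from a Ray actor.
import ray


@ray.remote
class Probe:

    def actor_id(self):
        return ray.get_runtime_context().get_actor_id()


ray.init()
print(ray.get_runtime_context().get_actor_id())   # None in the driver
print(ray.get(Probe.remote().actor_id.remote()))  # hex actor id inside the actor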