Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/offline_inference/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@

class MyLLM(LLM):
    """``LLM`` subclass that prepares environment variables before vLLM starts.

    The environment is adjusted inside ``__init__`` so the changes are in
    place before ``LLM.__init__`` runs.
    """

    def __init__(self, *args, pg_name=None, **kwargs):
        """Set up the process environment, then construct the underlying LLM.

        Args:
            *args: Positional arguments forwarded unchanged to ``LLM``.
            pg_name: Optional name of the Ray placement group for vLLM to
                use. When given, it is exported as ``VLLM_RAY_PG_NAME``.
                When ``None``, that variable is left untouched.
            **kwargs: Keyword arguments forwarded unchanged to ``LLM``.
        """
        # a hack to make the script work.
        # stop ray from manipulating CUDA_VISIBLE_DEVICES
        # at the top-level
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        if pg_name is not None:
            # set the placement group name for vLLM to use
            # (os.environ only accepts strings, so None must be skipped)
            os.environ["VLLM_RAY_PG_NAME"] = pg_name
        # set the ray address for vLLM to use
        os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
        super().__init__(*args, **kwargs)


Expand All @@ -50,7 +54,7 @@ def __init__(self, *args, **kwargs):
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2, name="inference")
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
placement_group=pg_inference,
Expand All @@ -71,6 +75,7 @@ def __init__(self, *args, **kwargs):
worker_extension_cls="rlhf_utils.WorkerExtension",
tensor_parallel_size=2,
distributed_executor_backend="ray",
pg_name="inference",
)

# Generate texts from the prompts.
Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_RAY_PG_NAME: Optional[str] = None
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0
Expand Down Expand Up @@ -590,6 +591,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
"VLLM_RAY_BUNDLE_INDICES":
lambda: os.getenv("VLLM_RAY_BUNDLE_INDICES", ""),

# Name of an existing Ray placement group. If set, vLLM looks up that
# placement group by name and uses it instead of the current one.
"VLLM_RAY_PG_NAME":
lambda: os.getenv("VLLM_RAY_PG_NAME", None),

# When on a Nvidia GPU aligns single entries (within a page) so they are 256
# byte aligned for better performance, this increases the memory usage of
# the cache. Currently this only affects MLA that results in non-256
Expand Down
11 changes: 10 additions & 1 deletion vllm/executor/ray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import msgspec

import vllm.envs as envs
import vllm.platforms
from vllm.config import ParallelConfig
from vllm.executor.msgspec_utils import decode_hook, encode_hook
Expand Down Expand Up @@ -304,7 +305,15 @@ def initialize_ray_cluster(
"support ray.")

# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if envs.VLLM_RAY_PG_NAME:
# the placement group is specified by the user
logger.info(
"Looking for the placement group specified by"
" VLLM_RAY_PG_NAME: %s", envs.VLLM_RAY_PG_NAME)
current_placement_group = ray.util.get_placement_group(
envs.VLLM_RAY_PG_NAME)
else:
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group:
# We are in a placement group
bundles = current_placement_group.bundle_specs
Expand Down