|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 |
|
| 3 | +import json |
3 | 4 | import os |
4 | 5 | import time |
5 | 6 | from collections import defaultdict |
6 | 7 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union |
7 | 8 |
|
8 | 9 | import msgspec |
9 | 10 |
|
| 11 | +import vllm.envs as envs |
10 | 12 | import vllm.platforms |
11 | 13 | from vllm.config import ParallelConfig |
12 | 14 | from vllm.executor.msgspec_utils import decode_hook, encode_hook |
@@ -162,6 +164,41 @@ def assert_ray_available(): |
162 | 164 | "`pip install ray`.") from ray_import_err |
163 | 165 |
|
164 | 166 |
|
| 167 | +def serialize_placement_group_to_str(placement_group: "PlacementGroup") -> str: |
| 168 | + """Serialize a placement group to a string. |
| 169 | + FIXME: This should be implemented in Ray. |
| 170 | +
|
| 171 | + Args: |
| 172 | + placement_group: The placement group to serialize. |
| 173 | +
|
| 174 | + Returns: |
| 175 | + A string representation of the placement group. |
| 176 | + """ |
| 177 | + placement_group_data = { |
| 178 | + "id": placement_group.id.hex(), |
| 179 | + "bundle_cache": placement_group.bundle_cache, |
| 180 | + } |
| 181 | + return json.dumps(placement_group_data) |
| 182 | + |
| 183 | + |
| 184 | +def deserialize_placement_group_from_str( |
| 185 | + placement_group_str: str) -> "PlacementGroup": |
| 186 | + """Deserialize a placement group from a string. |
| 187 | + FIXME: This should be implemented in Ray. |
| 188 | +
|
| 189 | + Args: |
| 190 | + placement_group_str: The string representation of the placement group. |
| 191 | +
|
| 192 | + Returns: |
| 193 | + A placement group. |
| 194 | + """ |
| 195 | + placement_group_data = json.loads(placement_group_str) |
| 196 | + return PlacementGroup( |
| 197 | + id=ray._raylet.PlacementGroupID.from_hex(placement_group_data["id"]), |
| 198 | + bundle_cache=placement_group_data["bundle_cache"], |
| 199 | + ) |
| 200 | + |
| 201 | + |
165 | 202 | def _verify_bundles(placement_group: "PlacementGroup", |
166 | 203 | parallel_config: ParallelConfig, device_str: str): |
167 | 204 | """Verify a given placement group has bundles located in the right place. |
@@ -308,12 +345,19 @@ def initialize_ray_cluster( |
308 | 345 |
|
309 | 346 | # Create or get the placement group for worker processes |
310 | 347 | if parallel_config.placement_group: |
| 348 | + logger.info( |
| 349 | + "Using the existing Ray placement group from parallel config") |
311 | 350 | current_placement_group = parallel_config.placement_group |
| 351 | + elif envs.VLLM_RAY_PLACEMENT_GROUP: |
| 352 | + logger.info("Using the existing Ray placement group from " |
| 353 | + "VLLM_RAY_PLACEMENT_GROUP") |
| 354 | + current_placement_group = deserialize_placement_group_from_str( |
| 355 | + envs.VLLM_RAY_PLACEMENT_GROUP) |
312 | 356 | else: |
| 357 | + logger.info("Trying to get the existing Ray placement group") |
313 | 358 | current_placement_group = ray.util.get_current_placement_group() |
314 | 359 |
|
315 | 360 | if current_placement_group: |
316 | | - logger.info("Using the existing placement group") |
317 | 361 |
|
318 | 362 | # We are in a placement group |
319 | 363 | bundles = current_placement_group.bundle_specs |
|
0 commit comments