Conversation

hijkzzz (Contributor) commented Mar 22, 2025

No description provided.

github-actions bot commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

youkaichao changed the title from "Fix arg_utils.py for OpenRLHF using vLLM V1 engine" to "Fix v1 supported oracle for worker-cls and worker-extension-cls" on Mar 22, 2025
youkaichao (Member) commented Mar 22, 2025

full error log:

INFO 03-22 01:13:06 [__init__.py:256] Automatically detected platform cuda.
2025-03-22 01:13:12,038 INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
(pid=1621476) INFO 03-22 01:13:17 [__init__.py:256] Automatically detected platform cuda.
(MyLLM pid=1621476) INFO 03-22 01:13:24 [config.py:585] This model supports multiple tasks: {'reward', 'embed', 'classify', 'score', 'generate'}. Defaulting to 'generate'.
(MyLLM pid=1621476) INFO 03-22 01:13:24 [config.py:1693] Chunked prefill is enabled with max_num_batched_tokens=16384.
(MyLLM pid=1621476) WARNING 03-22 01:13:24 [cuda.py:96] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
(MyLLM pid=1621476) WARNING 03-22 01:13:26 [utils.py:2176] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/getting_started/troubleshooting.html#python-multiprocessing for more information. Reason: In a Ray actor and can only be spawned
(MyLLM pid=1621476) INFO 03-22 01:13:29 [__init__.py:256] Automatically detected platform cuda.
(MyLLM pid=1621476) INFO 03-22 01:13:31 [core.py:54] Initializing a V1 LLM engine (v0.7.3.dev876+g2fa0e1396) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}
(MyLLM pid=1621476) 2025-03-22 01:13:33,275     INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8266 
(MyLLM pid=1621476) INFO 03-22 01:13:34 [ray_utils.py:316] Using the existing placement group
...
(MyLLM pid=1621476) ERROR 03-22 01:13:34 [core.py:343]     ray._private.state.state.placement_group_table(pg_id)["bundles"].values()
(MyLLM pid=1621476) ERROR 03-22 01:13:34 [core.py:343] KeyError: 'bundles'

@comaniac it seems passing the placement group is not enough; the new process fails to recognize the old Ray cluster.

2d1e722 should fix it.
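
For context, a hypothetical reconstruction of the failure mode (illustrative sketch, not vLLM source; assumes Ray is installed): the spawned process calls ray.init() with no address, so it brings up its own local Ray instance (consistent with the second "Started a local Ray instance" line in the log), and the parent cluster's placement group cannot be resolved there, hence the KeyError: 'bundles'.

```python
import ray

# No address given: Ray starts a fresh local cluster in this process,
# not the one that created the placement group.
ray.init()
print(ray.get_runtime_context().gcs_address)  # differs from the parent's GCS address
```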

Signed-off-by: youkaichao <youkaichao@gmail.com>
Comment on lines -1463 to -1472
if self.worker_cls != EngineArgs.worker_cls:
    _raise_or_fallback(feature_name="--worker-cls",
                       recommend_to_remove=False)
    return False

if self.worker_extension_cls != EngineArgs.worker_extension_cls:
    _raise_or_fallback(feature_name="--worker-extension-cls",
                       recommend_to_remove=False)
    return False

Member:
cc @robertgshaw2-redhat for visibility.

mergify bot added the ci/build label Mar 22, 2025
- pushd ../examples/offline_inference
- python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py
Member:
In v1, LLM.collective_rpc is broken in the default case, because self.llm_engine.model_executor is in a different process.

cc @njhill @robertgshaw2-redhat if you can help fix it. I'm disabling VLLM_ENABLE_V1_MULTIPROCESSING for the test right now.
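
A minimal sketch of the workaround used for the test (the method name passed to collective_rpc is a hypothetical one implemented on a custom worker class, not a vLLM API): setting the env var before the engine is created keeps the V1 engine in the driver process, so collective_rpc can reach the model executor.

```python
import os

# Keep the V1 engine in the driver process so LLM.collective_rpc can reach
# the model executor directly; must be set before the LLM is constructed.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
# "echo_rank" is a hypothetical method implemented on a custom worker class.
results = llm.collective_rpc("echo_rank")
```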

Collaborator:
Do you think we should fall back to this automatically if the user selects --worker-cls?

Member:
We don't need to fall back; worker-cls and RLHF can be separate things.
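
A sketch of what removing the oracle check enables (the module path and class name below are hypothetical; the pattern mirrors examples/offline_inference/rlhf.py): an RLHF setup can pass a custom worker extension class and still run on the V1 engine instead of being forced back to V0.

```python
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",
    enforce_eager=True,
    # Hypothetical extension class an RLHF trainer uses to receive weight
    # updates; with this PR it no longer forces a fallback to the V0 engine.
    worker_extension_cls="my_rlhf_utils.WorkerExtension",
)
```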

vllm/utils.py Outdated
Comment on lines 2173 to 2177
# even if we choose to spawn, we need to pass the ray address to the
# subprocess, so that it knows to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import ray
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address
Member:
This is a follow-up to #14705.
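
Roughly the mechanism, as an illustrative sketch rather than the exact vLLM code: environment variables survive a spawn, and ray.init() honors RAY_ADDRESS, so the spawned engine process rejoins the parent's cluster instead of starting a second local instance.

```python
import os
import ray

# Parent process (already inside a Ray actor, i.e. connected to the cluster):
# export the GCS address so it is inherited across the spawn boundary.
os.environ["RAY_ADDRESS"] = ray.get_runtime_context().gcs_address

# Spawned child process: a plain ray.init() picks up RAY_ADDRESS and joins
# the existing cluster, so the parent's placement group is visible again.
ray.init()
```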

Collaborator:
Ah, makes sense. I only tested the Ray Serve use case before. cc @ruisearch42

ruisearch42 (Collaborator) left a comment:

LG

# even if we choose to spawn, we need to pass the ray address to the
# subprocess, so that it knows to connect to the ray cluster.
# env vars are inherited by subprocesses, even if we use spawn.
import ray
Collaborator:
Ray is not always installed for vLLM; we need a try-except like is_in_ray_actor().

Member:
This is guarded under is_in_ray_actor(), so it is safe.
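
For reference, a minimal sketch of the kind of guard being discussed (not necessarily the exact vLLM implementation): the import lives inside the function, so the check returns False when Ray is not installed instead of raising.

```python
def is_in_ray_actor() -> bool:
    """Best-effort check for whether we are running inside a Ray actor."""
    try:
        import ray
        return (ray.is_initialized()
                and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        return False
```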

Collaborator:
Yeah, that's right; I realized it afterwards.

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao merged commit 0661cfe into vllm-project:main on Mar 23, 2025
5 of 8 checks passed
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
…-project#15324)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
…-project#15324)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Wes Medford <wryanmedford@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…-project#15324)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…-project#15324)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…-project#15324)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…-project#15324)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>