Conversation

ruisearch42 (Collaborator) commented Mar 13, 2025

By default, VLLM_WORKER_MULTIPROC_METHOD is set to fork. However, forking from within a Ray actor causes undefined behavior, which leads to hangs in placement group methods when using V1.

This PR fixes the issue by detecting whether vLLM is running as a Ray actor and, if so, using spawn instead of fork for new process creation.

This PR also skips Ray initialization if Ray has already been initialized.
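
A minimal sketch of the described behavior (the helper names and the exact detection logic below are illustrative, not the PR's actual code):

```python
# Illustrative sketch only (not the PR's exact code): force "spawn" when
# vLLM is running inside a Ray actor, since fork() in a Ray actor process
# leads to undefined behavior.
import os

import ray


def maybe_force_spawn() -> None:
    """Switch the worker start method to "spawn" when inside a Ray actor."""
    if not ray.is_initialized():
        return
    # get_actor_id() returns None when this process is not a Ray actor.
    if ray.get_runtime_context().get_actor_id() is not None:
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"


def ensure_ray_initialized() -> None:
    """Skip ray.init() when Ray was already initialized by the caller."""
    if not ray.is_initialized():
        ray.init()
```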

Tested with `VLLM_USE_V1=1 python /home/ubuntu/vllm/examples/offline_inference/rlhf.py`; the engine started and served requests:

(MyLLM pid=2966292) INFO 03-13 16:18:58 [core.py:51] Initializing a V1 LLM engine (v0.6.6.dev668+gc4c11eae.d20250209) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=facebook/opt-125m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}

Prompt: 'Hello, my name is', Generated text: ' J.C. and I am a student at the University of California, Berkeley'
Prompt: 'The president of the United States is', Generated text: " not a racist. He is a racist.\nHe's a racist because he"
Prompt: 'The capital of France is', Generated text: ' the capital of the French Republic.\n\nThe capital of France is the capital'
Prompt: 'The future of AI is', Generated text: ' in the hands of the people.\n\nThe future of AI is in the'

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
github-actions (bot) commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
"support ray.")

# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
Collaborator:

You won't be able to get the current placement group with spawn, right?

Collaborator Author (ruisearch42):

Right, we will need to pass it via PG_NAME, as in #14410.
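
For context, a named placement group can be looked up from another process by name within a namespace. The sketch below is purely illustrative (the PG_NAME variable, the "vllm" namespace, and the bundle shapes are made up, and #14410's actual mechanism may differ):

```python
# Illustrative only: recovering a placement group by name after "spawn".
import os

import ray

# Parent (inside the Ray actor): create or locate a *named* placement
# group and advertise its name to the child via the environment.
ray.init(namespace="vllm", ignore_reinit_error=True)
pg = ray.util.placement_group([{"GPU": 1}] * 2, name="vllm_pg")
ray.get(pg.ready())
os.environ["PG_NAME"] = "vllm_pg"

# Child (after spawn): reconnect to the same cluster and namespace, then
# look the placement group up by name instead of calling
# ray.util.get_current_placement_group().
# ray.init(address="auto", namespace="vllm")
# pg = ray.util.get_placement_group(os.environ["PG_NAME"])
```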

Collaborator:

It's better not to introduce more environment variables if possible. Why can't we just pass the placement group?

Member:

Env vars are automatically inherited, even if the process is spawned.

If we pass the placement group as an argument, then we need to explicitly pass it somewhere, and where that is might vary depending on the entrypoint function of the new process.

I'm fine with adding new env vars, as long as they are well scoped.
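
A small self-contained demonstration of the inheritance point above (standalone example, not vLLM code): environment variables set before spawning are visible in the spawned child without being passed explicitly.

```python
# Standalone demo: env vars survive "spawn", unlike ordinary Python globals.
import multiprocessing as mp
import os


def child() -> None:
    # Prints "spawn": the child re-imports modules and loses module-level
    # state, but it still inherits the parent's environment.
    print(os.environ.get("VLLM_WORKER_MULTIPROC_METHOD"))


if __name__ == "__main__":
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    proc = mp.get_context("spawn").Process(target=child)
    proc.start()
    proc.join()
```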

Member:

One drawback of this approach is that we need the placement group to be named and placed in a namespace, while in general that is not the case for placement groups.

Member:

Is it possible to get some IDs from the placement group and then directly construct the placement group from those IDs? It seems to be a pretty common use case for Ray, and I think there should be some solution other than a named placement group.

Collaborator Author (ruisearch42), Mar 14, 2025:

I think the comment makes sense; that would be ideal. On the other hand, the fundamental reason we needed a named placement group and a namespace is that we are in a new Ray job after the spawn: the namespace and name are used for isolation. The way to tackle that is to use Ray to manage process resources rather than Python multiprocessing. That is a somewhat longer-term solution, as it needs quite a few changes in vLLM. So unfortunately, using an ID instead of name/namespace is not really viable.

Member:

What @comaniac did in #14705 already implements this kind of functionality, I think: it passes the placement group object across processes. We should be able to do it via an environment variable so that we don't need to pass it to specific functions.

For example, inside the _check_multiproc_method function, if you detect a Ray placement group, not only set VLLM_WORKER_MULTIPROC_METHOD, but also set VLLM_RAY_PG_HANDLE to contain some information from the placement group.

Then, in ray_utils.py, if VLLM_RAY_PG_HANDLE is set, you parse the information and create the placement group object again.
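
A rough sketch of that proposal (VLLM_RAY_PG_HANDLE is only a name suggested in this thread, not an existing vLLM env var; the function bodies and the name/namespace encoding are illustrative, and as noted below this only helps for named placement groups):

```python
# Sketch of the proposed flow; names and encoding are illustrative.
import os

import ray


def _check_multiproc_method() -> None:  # simplified stand-in
    pg = ray.util.get_current_placement_group()
    if pg is None:
        return
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
    # Ray's only lookup key for a placement group is name + namespace,
    # so this only works when the PG was created with a name.
    name = ray.util.placement_group_table(pg).get("name")
    namespace = ray.get_runtime_context().namespace
    if name:
        os.environ["VLLM_RAY_PG_HANDLE"] = f"{namespace}:{name}"


# Child side (e.g. in ray_utils.py): rebuild the placement group object.
def _recover_placement_group():
    handle = os.environ.get("VLLM_RAY_PG_HANDLE")
    if handle is None:
        return None
    namespace, name = handle.split(":", 1)
    ray.init(address="auto", namespace=namespace, ignore_reinit_error=True)
    return ray.util.get_placement_group(name)
```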

Collaborator (comaniac), Mar 14, 2025:

One issue is that we may not be able to pass the placement group object via an env var, especially when the placement group is not created by the user. For example, when using Ray Serve to launch vLLM, the placement group is created by Ray Serve, and Ray Serve won't put the placement group it created into an env var. Instead, get_current_placement_group() is the more desirable pattern for sharing a placement group.
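
For illustration, this is the pattern being referred to: an actor scheduled into a placement group (as Ray Serve does) can discover it via get_current_placement_group() without the group ever being passed explicitly. Standalone example, not vLLM code:

```python
# Standalone illustration of the get_current_placement_group() pattern.
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote
class Worker:
    def bundles(self):
        # Inside the actor: discover the PG we were scheduled into.
        pg = ray.util.get_current_placement_group()
        return pg.bundle_specs if pg is not None else None


ray.init()
pg = ray.util.placement_group([{"CPU": 1}])
ray.get(pg.ready())
worker = Worker.options(
    scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
).remote()
print(ray.get(worker.bundles.remote()))  # e.g. [{'CPU': 1.0}]
```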

Collaborator Author (ruisearch42):

Also, to Kaichao's comment:

A placement group is a non-string object; I don't think we can pass it via environment variables. Did you mean to serialize/pickle it first?

Also, by VLLM_RAY_PG_HANDLE, did you mean the imaginary "ID"? There is no PG handle. The Ray API only supports namespace + name: that's the only "handle" supported.

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
mergify bot added the documentation (Improvements or additions to documentation) label on Mar 13, 2025
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>