[Feature] Support Pipeline Parallelism in torchrun SPMD offline inference for V1 #17827


Merged: 10 commits merged on May 16, 2025

Conversation

@luccafong (Collaborator) commented May 8, 2025

This PR adds support for pipeline parallelism (PP) in torchrun offline inference. Note that it does not yet support overlapping microbatches, so it is not the most efficient way to unblock PP use cases; this will be improved in follow-ups.

torchrun --nnodes 2 --nproc-per-node 2 --rdzv-id=random12345 --rdzv-backend=c10d --rdzv-endpoint=<host:port>  examples/offline_inference/torchrun_example.py

Output:

--------------------------------------------------
Prompt: 'Hello, my name is'
Generated text: ' Hilary and I have been a full time Licensed Massage Therapist since 199'

--------------------------------------------------
Prompt: 'The president of the United States is'
Generated text: ' a leader of the free world. His actions and statements have a direct impact on'

--------------------------------------------------
Prompt: 'The capital of France is'
Generated text: ' the largest city in the country and a major European metropolis. It is also'

--------------------------------------------------
Prompt: 'The future of AI is'
Generated text: ' not as promising as you might think\nWith a preponderance of artificial intelligence'
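
For reference, here is a minimal sketch of the kind of script the command above launches under torchrun with PP enabled, assuming the external_launcher backend. The model name, prompts, and sampling settings are illustrative and not necessarily the exact contents of examples/offline_inference/torchrun_example.py.

import os

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Every torchrun rank builds the same LLM (SPMD). With --nnodes 2 and
# --nproc-per-node 2 there are 4 ranks, split here as TP=2 x PP=2.
llm = LLM(
    model="meta-llama/Llama-3.1-8B",   # illustrative model choice
    tensor_parallel_size=2,
    pipeline_parallel_size=2,
    distributed_executor_backend="external_launcher",
)

outputs = llm.generate(prompts, sampling_params)

# With this PR, the last PP stage broadcasts the results so every rank gets
# them; print from a single rank to avoid duplicated output.
if int(os.environ.get("RANK", "0")) == 0:
    for output in outputs:
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated text: {output.outputs[0].text!r}")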


github-actions bot commented May 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation and v1 labels May 8, 2025
@luccafong luccafong marked this pull request as ready for review May 8, 2025 16:29
@luccafong luccafong changed the title Support Pipeline Parallelism in torchrun SPMD offline inference for V1 [Feature] Support Pipeline Parallelism in torchrun SPMD offline inference for V1 May 8, 2025
@luccafong luccafong requested a review from youkaichao May 9, 2025 17:19
@comaniac (Collaborator) left a comment

Overall LGTM. The main question is that I'm not sure why we need pipeline_parallel_broadcast_output and cannot make it always true.

@youkaichao (Member) left a comment

For efficient PP, we would want different stages of PP from different batches to run concurrently. Would this broadcast make them serialized, i.e. there would be only one batch and one stage running for the full engine?

if len(get_pp_group().ranks) > 0:
    model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
        model_output_broadcast_data, src=last_rank_in_group)
    assert model_output_broadcast_data is not None
Collaborator

Should we return earlier if it's not the first rank?

@luccafong (Collaborator, Author) commented May 12, 2025

We still need it for a complete sync in the current solution.


mergify bot commented May 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 12, 2025
luccafong added 4 commits May 12, 2025 10:02
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
@luccafong (Collaborator, Author) commented:

For efficient PP, we would want different stages of PP from different batches to run concurrently. Would this broadcast make them serialized, i.e. there would be only one batch and one stage running for the full engine?

Thanks @youkaichao. Right now torchrun runs in a purely synced way to align with SPMD, so this is not the ideal solution for efficient PP: only one batch and one stage run at a time. I will think about how to design and improve it so multiple batches can overlap while staying compatible with torchrun, as a follow-up.

@houseroad houseroad added the ready label May 12, 2025
@houseroad (Collaborator) left a comment

As discussed with @luccafong, landing this first to unblock some use cases.

Will enhance the perf in the following PR.

luccafong added 2 commits May 12, 2025 15:23
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lucia Fang <fanglu@fb.com>
@ruisearch42 (Collaborator) left a comment

LG with a couple comments

last_rank_in_group = pp_group_ranks - 1
if self.parallel_config.distributed_executor_backend \
        == "external_launcher" and len(get_pp_group().ranks) > 0:
    model_output_broadcast_data = get_pp_group().broadcast_tensor_dict(
Collaborator

Could you explain why we need this broadcast?

Collaborator Author

Yes, added as comments. Right now we enable it by syncing all ranks; will improve to reduce PP bubbles in a following PR.
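
For context, a small standalone illustration (not code from this PR) of why the sync is needed: a collective like broadcast_tensor_dict requires every rank in the group to make the matching call, so non-last PP ranks cannot simply skip the step. The sketch below uses plain torch.distributed with an illustrative payload and can be launched with torchrun --nproc-per-node 2.

import torch.distributed as dist

def main():
    # torchrun sets the env vars needed by the default env:// initialization.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    last_rank = dist.get_world_size() - 1

    # Only the last rank holds the real result; other ranks pass a placeholder
    # that the broadcast will overwrite.
    payload = [{"sampled_token_ids": [1, 2, 3]} if rank == last_rank else None]

    # Every rank must enter this collective, which keeps the ranks in lockstep
    # (and, for now, serializes the PP stages).
    dist.broadcast_object_list(payload, src=last_rank)
    print(f"rank {rank} received: {payload[0]}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()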

Signed-off-by: Lucia Fang <fanglu@fb.com>
@mergify mergify bot added the ci/build label May 13, 2025

mergify bot commented May 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luccafong.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 14, 2025
Signed-off-by: Lucia Fang <fanglu@fb.com>
@mergify mergify bot removed the needs-rebase label May 15, 2025
@DarkLight1337 DarkLight1337 added this to the v0.9.0 milestone May 15, 2025
@luccafong (Collaborator, Author) commented:

The CI build failure is not related; also pulled the latest changes. @houseroad

@houseroad (Collaborator) commented:

Actually v1 test failure may be related: E AssertionError: pipeline model parallel group is not initialized

cc: @luccafong

@luccafong (Collaborator, Author) commented:

Actually v1 test failure may be related: E AssertionError: pipeline model parallel group is not initialized

cc: @luccafong

This is a trunk failure.

Added a fix here:
#18223

@houseroad houseroad enabled auto-merge (squash) May 15, 2025 22:59
@simon-mo simon-mo disabled auto-merge May 16, 2025 05:28
@simon-mo simon-mo merged commit 3d2779c into vllm-project:main May 16, 2025
86 of 90 checks passed
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…ce for V1 (vllm-project#17827)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…ce for V1 (vllm-project#17827)

Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Labels: ci/build, documentation, ready, v1
7 participants