Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Pipeline parallel with Ray ADAG #6837

Merged
merged 4 commits into from
Aug 2, 2024

Conversation

ruisearch42
Copy link
Contributor

@ruisearch42 ruisearch42 commented Jul 26, 2024

Support pipeline-parallelism with Ray accelerated DAG.

Benchmarking result

“avg latency” column format: adag_nccl, adag_shm, mp, ray (baseline)

  • adag_nccl: ADAG backend, with NCCL communication between PP stages

  • adag_shm: ADAG backend, with shared memory communication between PP stages

  • mp: multiprocessing backend; (not supported for multi-nodes, using N/A to indicate) 

  • ray: ray backend, used as baseline

“% comparison” column format: (adag_nccl / ray) * 100%, (adag_shm / ray) * 100%, (mp / ray) * 100%

  • The three comparisons use ray as baseline

  • < 100%: latency is better

  • > 100%: latency is worse


Nodes GPU Setup Model input_len output_len qps avg latency % comparison
Single V100 PP=2, TP=2 Llama-2-7b-chat-hf 32 32 3 17.8, 16.8, 23.4, 23.0 77.4%, 73.0%, 101.7%
Single V100 PP=2, TP=2 Llama-2-7b-chat-hf 16 16 3 15.4, 16.2, 22.6, 25.6 60.2%, 63.2%, 88.3%
Multi A10 PP=2, TP=4 Llama-2-13b-hf 16 16 3 19.9, 19.3, N/A, 23.6 84.3%, 81.8%, N/A

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ruisearch42
Copy link
Contributor Author

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 26, 2024
@ruisearch42 ruisearch42 force-pushed the pp-adag-uni branch 2 times, most recently from 038f7cb to d0a7250 Compare July 27, 2024 00:26
@ruisearch42 ruisearch42 marked this pull request as ready for review July 29, 2024 15:01
Copy link
Contributor

@stephanie-wang stephanie-wang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool!

requirements-cuda.txt Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Outdated Show resolved Hide resolved
@stephanie-wang
Copy link
Contributor

@andoorve would be great to get your review as well.

@andoorve
Copy link
Collaborator

Will take a look today/tomorrow!

vllm/envs.py Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
assert pp_rank < len(self.pp_tp_workers)
self.pp_tp_workers[pp_rank].append(self.workers[rank])

# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible to somehow integrate this with the previous self.pp_tp_workers? If not no worries

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I attempted twice and it broke either driver mode or adag :( . I think there are pros and cons whether to unify with the same code path. If there is consensus we should go with unification, I can try harder to achieve that.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's a super huge deal since it's limited to this backend. But if it's possible to do so with a little work I think it cleans up the code a little bit. Can ask for @youkaichao opinion as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @youkaichao, do you have a preference whether we separate the implementations of spmd (aDAG) vs non-spmd or unify them at this point? I had the impression you may perfer keeping them separate for now but wanted to check.

requirements-cuda.txt Outdated Show resolved Hide resolved
tests/distributed/test_pipeline_parallel.py Show resolved Hide resolved
vllm/envs.py Outdated Show resolved Hide resolved
vllm/envs.py Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Outdated Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Show resolved Hide resolved
vllm/executor/ray_gpu_executor.py Outdated Show resolved Hide resolved
vllm/executor/ray_utils.py Show resolved Hide resolved
vllm/worker/worker_base.py Outdated Show resolved Hide resolved
@ruisearch42 ruisearch42 force-pushed the pp-adag-uni branch 3 times, most recently from 327111e to 4fd1525 Compare July 30, 2024 15:50
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! @andoorve can you also give a quick approval if it looks okay?

Copy link
Collaborator

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! @andoorve can you also give a quick approval if it looks okay?

requirements-cuda.txt Outdated Show resolved Hide resolved
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
@andoorve
Copy link
Collaborator

andoorve commented Aug 1, 2024

Yes @rkooo567 give me a little time to re-review

@youkaichao
Copy link
Member

sorry, didn't notice that you pinged me in this PR. will take a look later.

Copy link
Collaborator

@andoorve andoorve left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkooo567 I took a look again, only 2 things:

  1. If @ruisearch42 can confirm this has been tested locally:

Awesome!! Just a small thing you might want to run the multinode tests a few times. This passed like 50% of the time, as the failures we saw before adding extra logic for worker_ranks were intermittent

  1. If @youkaichao can respond to the comments above where we've pinged

Code itself LGTM

@rkooo567
Copy link
Collaborator

rkooo567 commented Aug 1, 2024

@youkaichao can you also take a look at this PR #6903 (it is not ready at all, but if you review the general approach that would be great)? I think this PR is pretty self-contained with ray backend, but the other PR will be more larger scope

@rkooo567
Copy link
Collaborator

rkooo567 commented Aug 2, 2024

@youkaichao since the change is pretty self contained, I am merging it. Please let us know if you have any feedback! We will follow up immediately

@rkooo567 rkooo567 merged commit 0530889 into vllm-project:main Aug 2, 2024
63 checks passed
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
Support pipeline-parallelism with Ray accelerated DAG.

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Support pipeline-parallelism with Ray accelerated DAG.

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Support pipeline-parallelism with Ray accelerated DAG.

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Signed-off-by: Alvant <alvasian@yandex.ru>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants