Skip to content

[TPU] Support multi-host inference #7457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 13, 2024
Merged

[TPU] Support multi-host inference #7457

merged 3 commits into from
Aug 13, 2024

Conversation

WoosukKwon
Copy link
Collaborator

Changing global rank and world size into local rank and local world size.

@WoosukKwon WoosukKwon added the tpu Related to Google TPUs label Aug 13, 2024
@WoosukKwon WoosukKwon requested a review from youkaichao August 13, 2024 05:04
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@WoosukKwon WoosukKwon requested a review from youkaichao August 13, 2024 05:33
@youkaichao
Copy link
Member

In ray gpu executor, there are these lines:

def sort_by_driver_then_worker_ip(worker):
"""
Sort the workers based on 3 properties:
1. If the worker is on the same node as the driver (vllm engine),
it should be placed first.
2. Then, if the worker is on a node with fewer workers, it should
be placed first.
3. Finally, if the work is on a node with smaller IP address, it
should be placed first.
"""
ip = ray.get(worker.get_node_ip.remote())
return (ip != driver_ip, ip_counts[ip], ip)
# After sorting, the workers on the same node will be
# close to each other, and the workers on the driver
# node will be placed first.
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)

to make sure the worker index aligns with machine boundary. you might need it in TPU, too. Otherwise local ranks can be wrong. for example, rank 0, 1, 2, 4 in one node, and 3, 5, 6, 7 in another node.

@WoosukKwon
Copy link
Collaborator Author

@youkaichao Can you explain more?

@youkaichao
Copy link
Member

say you have 2 nodes, 8 TPUs in total.

ray actors are launched one by one. when you launch the first actor, it might live in node 0; when you launch the second actor, it might live in node 1. If you use the index of worker as global rank, then it will cause a problem.

@WoosukKwon
Copy link
Collaborator Author

@youkaichao Thanks for the explanation. Let me merge this PR first as there are users waiting for this and the scope of this PR is isolated to the TPU backend. I will address your comment in a followup PR.

@WoosukKwon WoosukKwon merged commit a08df83 into main Aug 13, 2024
17 of 21 checks passed
@WoosukKwon WoosukKwon deleted the tpu-multi-host branch August 13, 2024 23:31
@jaisong123
Copy link

Hi @WoosukKwon - thanks for the office hours today!

And thanks for your hard work on this! I’m eager to start using the multi-host inference support on TPUs. Do you know when this feature will be available for general use?

Thanks again!

@sparsh35
Copy link

sparsh35 commented Aug 27, 2024

And thanks for your hard work on this! I’m eager to start using the multi-host inference support on TPUs. Do you know when this feature will be available for general use?

I think it should work now , though I haven't tried it yet especially for bigger model like mistral large v2 , any doc or tutorial would help, @WoosukKwon

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Signed-off-by: Alvant <alvasian@yandex.ru>
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tpu Related to Google TPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants