[Doc] Documentation for distributed inference #261

Merged 4 commits on Jun 26, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -170,3 +170,6 @@ cython_debug/

# Python pickle files
*.pkl

# Sphinx documentation
_build/
2 changes: 1 addition & 1 deletion README.md
@@ -28,7 +28,7 @@ vLLM is fast with:

- State-of-the-art serving throughput
- Efficient management of attention key and value memory with **PagedAttention**
-- Dynamic batching of incoming requests
+- Continuous batching of incoming requests
- Optimized CUDA kernels

vLLM is flexible and easy to use with:
14 changes: 12 additions & 2 deletions docs/source/index.rst
@@ -29,7 +29,7 @@ vLLM is fast with:

* State-of-the-art serving throughput
* Efficient management of attention key and value memory with **PagedAttention**
-* Dynamic batching of incoming requests
+* Continuous batching of incoming requests
* Optimized CUDA kernels

vLLM is flexible and easy to use with:
@@ -40,7 +40,11 @@ vLLM is flexible and easy to use with:
* Streaming outputs
* OpenAI-compatible API server

-For more information, please refer to our `blog post <https://vllm.ai>`_.
+For more information, check out the following:
+
+* `vLLM announcement blog post <https://vllm.ai>`_ (intro to PagedAttention)
+* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.



Documentation
@@ -53,6 +57,12 @@ Documentation
getting_started/installation
getting_started/quickstart

.. toctree::
:maxdepth: 1
:caption: Serving

serving/distributed_serving

.. toctree::
:maxdepth: 1
:caption: Models
38 changes: 38 additions & 0 deletions docs/source/serving/distributed_serving.rst
@@ -0,0 +1,38 @@
.. _distributed_serving:

Distributed Inference and Serving
=================================

vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:

.. code-block:: console

    $ pip install ray

To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:

.. code-block:: python

    from vllm import LLM
    llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
    output = llm.generate("San Francisco is a")
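
The following is an illustrative sketch (not part of the example above) of how the returned outputs might be inspected; it assumes :code:`generate` returns one request output per prompt, with the generated completions under :code:`outputs`:

.. code-block:: python

    # Print the prompt and the first completion for each request.
    for request_output in output:
        print(request_output.prompt, request_output.outputs[0].text)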

To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run the API server on 4 GPUs:

.. code-block:: console

    $ python -m vllm.entrypoints.api_server \
    $     --model facebook/opt-13b \
    $     --tensor-parallel-size 4
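
Once the server is up, you can query it over HTTP. Below is a hypothetical sketch using Python's :code:`requests` library; it assumes the demo API server listens on port 8000 and exposes a :code:`/generate` endpoint that accepts a JSON body with a :code:`prompt` field, so adjust it to match your setup:

.. code-block:: python

    import requests

    # Hypothetical request; adjust the host, port, endpoint, and fields
    # to match the server started above.
    response = requests.post(
        "http://localhost:8000/generate",
        json={"prompt": "San Francisco is a", "max_tokens": 16},
    )
    print(response.json())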

To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:

.. code-block:: console

    $ # On head node
    $ ray start --head

    $ # On worker nodes
    $ ray start --address=<ray-head-address>

After that, you can run inference and serving across multiple machines by launching the vLLM process on the head node and setting :code:`tensor_parallel_size` to the total number of GPUs across all machines.
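
As a hypothetical example, with two machines of 4 GPUs each, the head node can launch a single engine that spans all 8 GPUs (a sketch, assuming the Ray cluster above is already running):

.. code-block:: python

    from vllm import LLM

    # 8 = total number of GPUs across both machines in this assumed setup.
    llm = LLM("facebook/opt-13b", tensor_parallel_size=8)
    output = llm.generate("San Francisco is a")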