[Doc] Improve debugging documentation (vllm-project#9204)
Signed-off-by: Rafael Vasquez <rafvasq21@gmail.com>
rafvasq authored Oct 10, 2024
1 parent b37dac0 commit 7326a7c
docs/source/getting_started/debugging.rst

.. _debugging:

===============
Debugging Tips
===============

This document outlines some debugging strategies you can consider. If you think you've discovered a bug, please `search existing issues <https://github.com/vllm-project/vllm/issues?q=is%3Aissue>`_ first to see if it has already been reported. If not, please `file a new issue <https://github.com/vllm-project/vllm/issues/new/choose>`_, providing as much relevant information as possible.

.. note::

   Once you've debugged a problem, remember to turn off any debugging environment variables you defined, or simply start a new shell to avoid lingering debugging settings. Otherwise, the system may run slowly with debugging features left enabled.

Hangs downloading a model
----------------------------------------
If the model isn't already downloaded to disk, vLLM will download it from the internet, which can take time depending on your internet connection.
It's recommended to download the model first using the `huggingface-cli <https://huggingface.co/docs/huggingface_hub/en/guides/cli>`_ and then pass the local path to vLLM. This way, you can isolate the issue.
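
For example, here is a minimal sketch of this workflow using the ``huggingface_hub`` Python API (the model ID ``facebook/opt-125m`` and the local directory are placeholders; substitute your own):

.. code-block:: python

   from huggingface_hub import snapshot_download

   from vllm import LLM

   # Download the model (or reuse a cached copy) into a local directory first,
   # so that loading it in vLLM involves no network access at all.
   local_path = snapshot_download("facebook/opt-125m", local_dir="./opt-125m")

   # If this step hangs, the problem is not the download.
   llm = LLM(model=local_path)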

Hangs loading a model from disk
----------------------------------------
If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow.
It'd be better to store the model on a local disk. Additionally, keep an eye on CPU memory usage: when the model is too large, it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory.

Model is too large
----------------------------------------
If the model is too large to fit on a single GPU, you might want to `consider tensor parallelism <https://docs.vllm.ai/en/latest/serving/distributed_serving.html#distributed-inference-and-serving>`_ to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the tensor parallel size). You can convert the model checkpoint to a sharded checkpoint using `this example <https://docs.vllm.ai/en/latest/getting_started/examples/save_sharded_state.html>`_. The conversion process might take some time, but afterwards you can load the sharded checkpoint much faster, and the loading time stays constant regardless of the tensor parallel size.
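
As a rough sketch of enabling tensor parallelism from the Python API (the model name and the ``tensor_parallel_size`` value are placeholders; adjust them to your hardware):

.. code-block:: python

   from vllm import LLM

   # Shard the model across 4 GPUs with tensor parallelism. Without a
   # pre-sharded checkpoint, each of the 4 worker processes reads the full
   # checkpoint from disk, which multiplies the loading time.
   llm = LLM(model="meta-llama/Meta-Llama-3-70B-Instruct", tensor_parallel_size=4)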

Enable more logging
----------------------------------------
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:

- ``export VLLM_LOGGING_LEVEL=DEBUG`` to turn on more logging.
- ``export CUDA_LAUNCH_BLOCKING=1`` to identify which CUDA kernel is causing the problem.
- ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
- ``export VLLM_TRACE_FUNCTION=1`` to record all function calls, which you can then inspect in the log files to see which function crashes or hangs.
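
If you launch vLLM from a Python script rather than from a shell, a sketch of the equivalent setup is to set the same variables before vLLM (and the underlying CUDA/NCCL libraries) initialize; the model name below is only a placeholder:

.. code-block:: python

   import os

   # Set the debugging variables before importing vllm so they are picked up
   # when vLLM, CUDA, and NCCL initialize.
   os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
   os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
   os.environ["NCCL_DEBUG"] = "TRACE"
   os.environ["VLLM_TRACE_FUNCTION"] = "1"

   from vllm import LLM

   llm = LLM(model="facebook/opt-125m")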

Incorrect network setup
----------------------------------------
The vLLM instance may not get the correct IP address if you have a complicated network configuration. You can find a log line such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``; check that the IP address in it is the correct one.
If it's not, override the IP address using the environment variable ``export VLLM_HOST_IP=<your_ip_address>``.

You might also need to set ``export NCCL_SOCKET_IFNAME=<your_network_interface>`` and ``export GLOO_SOCKET_IFNAME=<your_network_interface>`` to specify the network interface for the IP address.

Error near ``self.graph.replay()``
----------------------------------------
If vLLM crashes and the error trace shows the crash somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a CUDA error inside CUDAGraph.
To identify the particular CUDA operation that causes the error, you can add ``--enforce-eager`` to the command line, or pass ``enforce_eager=True`` to the :class:`~vllm.LLM` class, to disable the CUDAGraph optimization; with CUDAGraph disabled, the trace points to the exact CUDA operation that failed.
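
For example, a minimal sketch with the Python API (the model name is a placeholder):

.. code-block:: python

   from vllm import LLM

   # Disable CUDAGraph capture/replay so the failing CUDA operation shows up
   # directly in the error trace instead of inside self.graph.replay().
   llm = LLM(model="facebook/opt-125m", enforce_eager=True)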

Incorrect hardware/driver
----------------------------------------
If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly.

.. code-block:: python

   # ... (earlier part of the sanity-check script is collapsed in this view)
   dist.destroy_process_group(gloo_group)
   dist.destroy_process_group()

.. tip::

   Save the script as ``test.py``.

   If you are testing with a single node, adjust ``--nproc-per-node`` to the number of GPUs you want to use:

   .. code-block:: shell

      NCCL_DEBUG=TRACE torchrun --nproc-per-node=<number-of-GPUs> test.py

   If you are testing with multiple nodes, adjust ``--nproc-per-node`` and ``--nnodes`` according to your setup and set ``MASTER_ADDR`` to the correct IP address of the master node, reachable from all nodes. Then, run:

   .. code-block:: shell

      NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR test.py

If the script runs successfully, you should see the message ``sanity check is successful!``.

If the problem persists, feel free to `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_, with a detailed description of the issue, your environment, and the logs.
.. note::

   A multi-node environment is more complicated than a single-node one. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign the node rank and specify the IP via command line arguments:

   - In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``.
   - In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``.

   Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup, being sure to execute different commands (with different ``--node-rank``) on different nodes.

Known Issues
----------------------------------------
- In ``v0.5.2``, ``v0.5.3``, and ``v0.5.3.post1``, there is a bug caused by `zmq <https://github.com/zeromq/pyzmq/issues/2000>`_, which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of ``vllm`` to include the `fix <https://github.com/vllm-project/vllm/pull/6759>`_.
