
Write info to local json
[ROCm] Fix build problem resulting from previous commit related to FP8 kv-cache support (vllm-project#2790)

Add documentation on how to do incremental builds (vllm-project#2796)

[Ray] Integrate compiled DAG, off by default (vllm-project#2471)

Disable custom all reduce by default (vllm-project#2808)

add usage context

removed usage_context from EngineArgs

Move IO to another process

added HTTP request

[ROCm] support Radeon™ 7900 series (gfx1100) without using flash-attention (vllm-project#2768)

Add documentation section about LoRA (vllm-project#2834)

Refactor 2 awq gemm kernels into m16nXk32 (vllm-project#2723)

Co-authored-by: Chunan Zeng <chunanzeng@Chunans-Air.attlocal.net>

Added usage_context argument to from_engine_args

comments
yhu422 committed Feb 13, 2024
1 parent fe6d09a commit 8a2f18a
Showing 20 changed files with 414 additions and 329 deletions.
16 changes: 13 additions & 3 deletions Dockerfile.rocm
@@ -17,6 +17,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"

# Whether to build flash-attention.
# If 0, flash-attention will not be built.
# This is useful for gfx targets where flash-attention is not supported,
# in which case vLLM falls back to its Python reference attention implementation.
ARG BUILD_FA="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -50,7 +56,8 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
-RUN mkdir libs \
+RUN if [ "$BUILD_FA" == "1" ]; then \
+mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
&& cd flash-attention \
@@ -60,7 +67,8 @@ RUN mkdir libs \
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
-&& cd ..
+&& cd ..; \
+fi

COPY ./ /app/vllm

@@ -75,7 +83,9 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
-&& bash patch_xformers.rocm.sh \
+&& if [ "$BUILD_FA" == "1" ]; then \
+bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cd ..

366 changes: 72 additions & 294 deletions csrc/quantization/awq/gemm_kernels.cu

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion docs/source/getting_started/amd-installation.rst
@@ -12,7 +12,7 @@ Requirements

* OS: Linux
* Python: 3.8 -- 3.11
-* GPU: MI200s (gfx90a), MI300 (gfx942)
+* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
* Pytorch 2.0.1/2.1.1/2.2
* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)

@@ -105,6 +105,7 @@ The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 until flash-attention supports this target.

Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
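
For example, a ROCm image for gfx1100 that skips flash-attention could be built roughly as follows (the ``vllm-rocm`` image tag is only an illustrative name):

.. code-block:: console

    $ docker build -f Dockerfile.rocm --build-arg BUILD_FA="0" -t vllm-rocm .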

10 changes: 10 additions & 0 deletions docs/source/getting_started/installation.rst
@@ -67,3 +67,13 @@ You can also build and install vLLM from source:
$ # Use `--ipc=host` to make sure the shared memory is large enough.
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
.. note::

    If you are developing the C++ backend of vLLM, consider building vLLM with

    .. code-block:: console

        $ python setup.py develop

    since it will give you incremental builds. The downside is that this method
    is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
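
A rough sketch of the resulting edit-rebuild loop (the paths are just examples of where changes might be made):

.. code-block:: console

    $ # edit sources, e.g. under csrc/ or vllm/
    $ python setup.py develop  # re-runs the build, recompiling only what changed
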
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -82,6 +82,7 @@ Documentation
models/supported_models
models/adding_model
models/engine_args
models/lora

.. toctree::
:maxdepth: 1
52 changes: 52 additions & 0 deletions docs/source/models/lora.rst
@@ -0,0 +1,52 @@
.. _lora:

Using LoRA adapters
===================

This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
Adapters can be efficiently served on a per-request basis with minimal overhead. First we download the adapter(s) and save
them locally with

.. code-block:: python

    from huggingface_hub import snapshot_download

    sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

Then we instantiate the base model and pass in the ``enable_lora=True`` flag:

.. code-block:: python

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

We can now submit the prompts and call ``llm.generate`` with the ``lora_request`` parameter. The first parameter
of ``LoRARequest`` is a human-identifiable name, the second parameter is a globally unique ID for the adapter, and
the third parameter is the path to the LoRA adapter.

.. code-block:: python

    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=256,
        stop=["[/assistant]"]
    )

    prompts = [
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
    ]

    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("sql_adapter", 1, sql_lora_path)
    )

Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
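
As a rough sketch (not taken from that example), serving two adapters from the same ``LLM`` instance might look like the following; the second adapter (``other_adapter`` / ``other_lora_path``) and the ``max_loras=2`` setting are illustrative assumptions, not part of this commit:

.. code-block:: python

    # Allow up to two adapters to be active in a single batch (assumed setting).
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=2)

    # Each call selects its adapter through its own LoRARequest;
    # the integer adapter IDs must be unique per adapter.
    sql_outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("sql_adapter", 1, sql_lora_path)
    )
    other_outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest("other_adapter", 2, other_lora_path)  # hypothetical second adapter
    )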
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ ray >= 2.9
sentencepiece # Required for LLaMA tokenizer.
numpy
torch == 2.1.2
cloud-detect
transformers >= 4.37.0 # Required for Qwen2
xformers == 0.0.23.post1 # Required for CUDA 12.1.
fastapi
15 changes: 15 additions & 0 deletions rocm_patch/rocm_bf16.patch
@@ -0,0 +1,15 @@
--- amd_hip_bf16.h 2024-02-06 18:28:58.268699142 +0000
+++ amd_hip_bf16.h.new 2024-02-06 18:28:31.988647133 +0000
@@ -90,10 +90,10 @@
#include "math_fwd.h" // ocml device functions

#if defined(__HIPCC_RTC__)
-#define __HOST_DEVICE__ __device__
+#define __HOST_DEVICE__ __device__ static
#else
#include <climits>
-#define __HOST_DEVICE__ __host__ __device__
+#define __HOST_DEVICE__ __host__ __device__ static inline
#endif

// Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on
7 changes: 6 additions & 1 deletion setup.py
@@ -15,11 +15,16 @@

ROOT_DIR = os.path.dirname(__file__)

# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
# The downside is that this method is deprecated, see
# https://github.com/pypa/setuptools/issues/917

MAIN_CUDA_VERSION = "12.1"

# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
-ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


26 changes: 18 additions & 8 deletions vllm/config.py
@@ -388,16 +388,26 @@ def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
-        if is_hip():
+        if not self.disable_custom_all_reduce and self.world_size > 1:
+            if is_hip():
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported on AMD GPUs.")
+            elif self.pipeline_parallel_size > 1:
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported with pipeline parallelism.")
+
+        # FIXME(woosuk): Fix the stability issues and re-enable the custom
+        # all-reduce kernel.
+        if not self.disable_custom_all_reduce and self.world_size > 1:
            self.disable_custom_all_reduce = True
            logger.info(
-                "Disabled the custom all-reduce kernel because it is not "
-                "supported on AMD GPUs.")
-        elif self.pipeline_parallel_size > 1:
-            self.disable_custom_all_reduce = True
-            logger.info(
-                "Disabled the custom all-reduce kernel because it is not "
-                "supported with pipeline parallelism.")
+                "Custom all-reduce kernels are temporarily disabled due to "
+                "stability issues. We will re-enable them once the issues are "
+                "resolved.")


class SchedulerConfig:
1 change: 0 additions & 1 deletion vllm/engine/arg_utils.py
@@ -6,7 +6,6 @@
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)


@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
9 changes: 6 additions & 3 deletions vllm/engine/async_llm_engine.py
@@ -12,7 +12,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)


@@ -613,7 +613,8 @@ async def get_model_config(self) -> ModelConfig:
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
-start_engine_loop: bool = True) -> "AsyncLLMEngine":
+start_engine_loop: bool = True,
+usage_context: UsageContext = UsageContext.UNKNOWN_CONTEXT) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
@@ -629,7 +630,9 @@ def from_engine_args(cls,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
-start_engine_loop=start_engine_loop)
+start_engine_loop=start_engine_loop,
+usage_context=usage_context
+)
return engine

async def do_log_stats(self) -> None:
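For reference, a minimal sketch of how a caller might use the new ``usage_context`` parameter; the model name is illustrative, and entry points would presumably pass a more specific context than the default:

.. code-block:: python

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.usage.usage_lib import UsageContext

    engine_args = AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf")

    # usage_context defaults to UsageContext.UNKNOWN_CONTEXT; passing it
    # explicitly lets usage reporting record which entry point built the engine.
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.UNKNOWN_CONTEXT)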
