Skip to content

Commit

Permalink
[Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update…
Browse files Browse the repository at this point in the history
… CK FA, add partially supported model notes (vllm-project#6543)

Signed-off-by: Alvant <alvasian@yandex.ru>
  • Loading branch information
mawong-amd authored and Alvant committed Oct 26, 2024
1 parent 66aee67 commit 594c454
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 39 deletions.
7 changes: 7 additions & 0 deletions .buildkite/run-amd-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,18 @@ trap remove_docker_container EXIT

echo "--- Running container"

HF_CACHE="$(realpath ~)/huggingface"
mkdir -p ${HF_CACHE}
HF_MOUNT="/root/.cache/huggingface"

docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--shm-size=16gb \
--rm \
-e HF_TOKEN \
-v ${HF_CACHE}:${HF_MOUNT} \
-e HF_HOME=${HF_MOUNT} \
--name ${container_name} \
${image_name} \
/bin/bash -c "${@}"
Expand Down
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ steps:
mirror_hardwares: [amd]
fast_check: true
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
Expand Down
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
# Try to find python package with an executable that exactly matches
Expand Down Expand Up @@ -101,7 +101,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
else()
Expand Down
60 changes: 35 additions & 25 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
# Whether to install CK-based flash-attention
# If 0, will not install flash-attention
ARG BUILD_FA="1"
# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
# If this succeeds, we use the downloaded wheel and skip building flash-attention.
# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
# architectures specified in `FA_GFX_ARCHS`
ARG TRY_FA_WHEEL="1"
ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="23a2b1c2"

# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848"
ARG TRITON_BRANCH="e0fc12c"

### Base image build stage
FROM $BASE_IMAGE AS base
Expand Down Expand Up @@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \
ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT}

RUN pip install --upgrade pip
RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \
python3 -m pip uninstall -y torch torchaudio torchvision \
&& python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
Expand All @@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
FROM base AS build_amdsmi
# Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=/install
&& python3 -m pip wheel . --wheel-dir=/install


### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG TRY_FA_WHEEL
ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
# If a suitable wheel exists, we download it instead of building FA
mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
else \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
fi; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
Expand Down Expand Up @@ -126,7 +136,7 @@ RUN case "$(which python3)" in \

# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade numba scipy huggingface-hub[cli]
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
Expand All @@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false

RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
python3 -m pip install -Ur requirements-rocm.txt \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
Expand All @@ -153,27 +163,27 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \
&& cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y amdsmi;
&& python3 -m pip uninstall -y amdsmi;

# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y triton; fi
&& python3 -m pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y flash-attn; fi
&& python3 -m pip uninstall -y flash-attn; fi

# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
pip install libs/*.whl; fi
python3 -m pip install libs/*.whl; fi

CMD ["/bin/bash"]
7 changes: 4 additions & 3 deletions docs/source/getting_started/amd-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor

Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_

2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_
2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_

Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
Alternatively, wheels intended for vLLM use can be accessed under the releases.

.. note::
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)

3. Build vLLM.
Expand All @@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
.. tip::

- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
- To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of PyTorch, ideally, should match the ROCm driver version.
4 changes: 4 additions & 0 deletions requirements-rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,9 @@
-r requirements-common.txt

# Dependencies for AMD GPUs
awscli
boto3
botocore
ray >= 2.10.0
peft
pytest-asyncio
9 changes: 7 additions & 2 deletions tests/basic_correctness/test_cpu_offload.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from vllm.utils import is_hip

from ..utils import compare_two_settings


def test_cpu_offload():
compare_two_settings("meta-llama/Llama-2-7b-hf", [],
["--cpu-offload-gb", "4"])
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
[], ["--cpu-offload-gb", "1"])
if not is_hip():
# compressed-tensors quantization is currently not supported in ROCm.
compare_two_settings(
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
["--cpu-offload-gb", "1"])
18 changes: 17 additions & 1 deletion tests/models/test_paligemma.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
Expand All @@ -22,6 +24,12 @@

models = ["google/paligemma-3b-mix-224"]

# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
Expand Down Expand Up @@ -130,7 +138,15 @@ def run_test(
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float", "half"])
@pytest.mark.parametrize("dtype", [
pytest.param(
"float",
marks=pytest.mark.skipif(
is_hip(),
reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half"
])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
Expand Down
9 changes: 8 additions & 1 deletion tests/models/test_phi3v.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
from typing import List, Optional, Tuple, Type

Expand All @@ -6,7 +7,7 @@

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from vllm.utils import is_cpu, is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
Expand Down Expand Up @@ -47,6 +48,12 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
if is_cpu():
target_dtype = "bfloat16"

# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def run_test(
hf_runner: Type[HfRunner],
Expand Down
8 changes: 8 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,12 @@ def __init__(
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
Expand Down Expand Up @@ -434,6 +440,8 @@ def forward(
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)

# common code for prefill
Expand Down
17 changes: 14 additions & 3 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,24 @@

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
_ROCM_SWA_REASON,
"MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
_ROCM_SWA_REASON,
"MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


Expand Down
9 changes: 8 additions & 1 deletion vllm/spec_decode/draft_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata

try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
Expand Down

0 comments on commit 594c454

Please sign in to comment.