[Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update CK FA, add partially supported model notes #6543

Status: Merged (12 commits, Jul 20, 2024)

7 changes: 7 additions & 0 deletions .buildkite/run-amd-test.sh
@@ -66,11 +66,18 @@ trap remove_docker_container EXIT

echo "--- Running container"

HF_CACHE="$(realpath ~)/huggingface"
mkdir -p ${HF_CACHE}
HF_MOUNT="/root/.cache/huggingface"

docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--shm-size=16gb \
--rm \
-e HF_TOKEN \
-v ${HF_CACHE}:${HF_MOUNT} \
-e HF_HOME=${HF_MOUNT} \
--name ${container_name} \
${image_name} \
/bin/bash -c "${@}"
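The net effect of the new cache mount is that Hugging Face downloads land on the host and survive the `--rm` container teardown, so repeat CI runs skip re-downloading models. A rough sketch of warming that host cache up front, assuming the `huggingface-cli` tool from `huggingface-hub[cli]` is available and using an illustrative model name:

# Hypothetical one-time warm-up of the host-side cache used by the AMD CI
# script; later containers that mount it at /root/.cache/huggingface and set
# HF_HOME to that path reuse these files instead of downloading them again.
export HF_HOME="$(realpath ~)/huggingface"
mkdir -p "${HF_HOME}"
huggingface-cli download facebook/opt-125m
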
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -44,7 +44,8 @@ steps:
mirror_hardwares: [amd]
fast_check: true
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
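The `|| true` suffix is what lets this step pass on ROCm: it turns the failed CUDA-only wheel install into a success so the remaining commands still run. A minimal illustration of the shell idiom (the commands below are placeholders, not the CI step itself):

#!/usr/bin/env bash
set -e           # abort on the first failing command, as CI steps typically do
false || true    # a failing optional command, neutralized by '|| true'
echo "pipeline step continues"
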
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

#
# Try to find python package with an executable that exactly matches
@@ -101,7 +101,7 @@ elseif(HIP_FOUND)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif()
else()
60 changes: 35 additions & 25 deletions Dockerfile.rocm
@@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
# Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to build CK-based flash-attention
# If 0, will not build flash attention
# This is useful for gfx target where flash-attention is not supported
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
# Triton FA is used by default on ROCm now so this is unnecessary.
# Whether to install CK-based flash-attention
# If 0, will not install flash-attention
ARG BUILD_FA="1"
# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
# If this succeeds, we use the downloaded wheel and skip building flash-attention.
# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
# architectures specified in `FA_GFX_ARCHS`
ARG TRY_FA_WHEEL="1"
ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="23a2b1c2"

# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="0ef1848"
ARG TRITON_BRANCH="e0fc12c"

### Base image build stage
FROM $BASE_IMAGE AS base
@@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \
ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT}

RUN pip install --upgrade pip
RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
pip uninstall -y torch torchaudio torchvision \
&& pip install --no-cache-dir --pre \
python3 -m pip uninstall -y torch torchaudio torchvision \
&& python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
torchvision==0.20.0.dev20240710 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
@@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
FROM base AS build_amdsmi
# Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=/install
&& python3 -m pip wheel . --wheel-dir=/install


### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG TRY_FA_WHEEL
ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
# If a suitable wheel exists, we download it instead of building FA
mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
else \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
fi; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
@@ -126,7 +136,7 @@ RUN case "$(which python3)" in \

# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
pip install --upgrade numba scipy huggingface-hub[cli]
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
@@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false

RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
python3 -m pip install -Ur requirements-rocm.txt \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
@@ -153,27 +163,27 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \
&& cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y amdsmi;
&& python3 -m pip uninstall -y amdsmi;

# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y triton; fi
&& python3 -m pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& pip uninstall -y flash-attn; fi
&& python3 -m pip uninstall -y flash-attn; fi

# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
pip install libs/*.whl; fi
python3 -m pip install libs/*.whl; fi

CMD ["/bin/bash"]
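Because `TRY_FA_WHEEL`, `FA_WHEEL_URL`, `FA_BRANCH`, and `FA_GFX_ARCHS` are plain build args, the wheel-versus-source behaviour can be chosen at image build time. A sketch of a local build that forces compiling CK flash-attention from source (the image tag is illustrative, not part of this PR):

# Hypothetical local build: skip the prebuilt wheel and compile CK
# flash-attention from FA_BRANCH for the listed GPU architectures.
docker build -f Dockerfile.rocm \
  --build-arg TRY_FA_WHEEL=0 \
  --build-arg FA_BRANCH=23a2b1c2 \
  --build-arg FA_GFX_ARCHS="gfx90a;gfx942" \
  -t vllm-rocm:dev .
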
7 changes: 4 additions & 3 deletions docs/source/getting_started/amd-installation.rst
@@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor

Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_

2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm>`_
2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_

Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_
Alternatively, wheels intended for vLLM use can be accessed under the releases.

.. note::
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
- You might need to downgrade the "ninja" version to 1.10 as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)

3. Build vLLM.
@@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl
.. tip::

- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
- To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
- The ROCm version of PyTorch, ideally, should match the ROCm driver version.
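As the tip notes, CK flash-attention (or PyTorch naive attention) is selected by turning Triton FA off through an environment variable. A minimal usage sketch, with an illustrative model and prompt:

# Disable the default Triton flash attention so vLLM falls back to the CK
# FA backend (or naive attention) on ROCm.
export VLLM_USE_TRITON_FLASH_ATTN=0
python3 -c "from vllm import LLM; print(LLM('facebook/opt-125m').generate('Hello, my name is'))"
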
4 changes: 4 additions & 0 deletions requirements-rocm.txt
@@ -2,5 +2,9 @@
-r requirements-common.txt

# Dependencies for AMD GPUs
awscli
boto3
botocore
ray >= 2.10.0
peft
pytest-asyncio
9 changes: 7 additions & 2 deletions tests/basic_correctness/test_cpu_offload.py
@@ -1,8 +1,13 @@
from vllm.utils import is_hip

from ..utils import compare_two_settings


def test_cpu_offload():
compare_two_settings("meta-llama/Llama-2-7b-hf", [],
["--cpu-offload-gb", "4"])
compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t",
[], ["--cpu-offload-gb", "1"])
if not is_hip():
# compressed-tensors quantization is currently not supported in ROCm.
compare_two_settings(
"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [],
["--cpu-offload-gb", "1"])
18 changes: 17 additions & 1 deletion tests/models/test_paligemma.py
@@ -1,10 +1,12 @@
import os
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@@ -22,6 +24,12 @@

models = ["google/paligemma-3b-mix-224"]

# ROCm Triton FA can run into compilation issues with these models due to,
# excessive use of shared memory. Use other backends in the meantime.
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
@@ -130,7 +138,15 @@ def run_test(
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float", "half"])
@pytest.mark.parametrize("dtype", [
pytest.param(
"float",
marks=pytest.mark.skipif(
is_hip(),
reason=
"ROCm FA does not yet fully support 32-bit precision on PaliGemma")
), "half"
])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
9 changes: 8 additions & 1 deletion tests/models/test_phi3v.py
@@ -1,3 +1,4 @@
import os
import re
from typing import List, Optional, Tuple, Type

@@ -6,7 +7,7 @@

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from vllm.utils import is_cpu, is_hip

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
@@ -47,6 +48,12 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
if is_cpu():
target_dtype = "bfloat16"

# ROCm Triton FA can run into shared memory issues with these models,
# use other backends in the meantime
# FIXME (mattwong, gshtrasb, hongxiayan)
if is_hip():
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"


def run_test(
hf_runner: Type[HfRunner],
8 changes: 8 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -275,6 +275,12 @@ def __init__(
triton_attention)
self.attn_func = triton_attention
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
@@ -434,6 +440,8 @@ def forward(
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)

# common code for prefill
17 changes: 14 additions & 3 deletions vllm/model_executor/models/__init__.py
@@ -87,13 +87,24 @@

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
_ROCM_SWA_REASON,
"MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
_ROCM_SWA_REASON,
"MixtralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


9 changes: 8 additions & 1 deletion vllm/spec_decode/draft_model_runner.py
@@ -3,7 +3,14 @@
import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata

try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)