V1 for fp4 #584

Open
wants to merge 6 commits into base: ROCm-7.0
64 changes: 32 additions & 32 deletions docker/Dockerfile.rocm_base
@@ -2,15 +2,15 @@ ARG BASE_IMAGE=compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compu
# ARG HIPBLASLT_BRANCH="aa0bda7b"
# ARG HIPBLAS_COMMON_BRANCH="9b80ba8e"
# ARG LEGACY_HIPBLASLT_OPTION=
ARG TRITON_BRANCH="916969a"
ARG TRITON_BRANCH="0e78e54a"
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
# ARG PYTORCH_BRANCH="37f92bb"
# ARG PYTORCH_VISION_BRANCH="v0.21.0"
# ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
# ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG PYTORCH_BRANCH="37f92bb"
ARG PYTORCH_VISION_BRANCH="95f10a4"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="8ede036"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="d765e80"
ARG AITER_BRANCH="498ff21"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
@@ -20,7 +20,7 @@ ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV AITER_ROCM_ARCH=gfx950
ENV AITER_ROCM_ARCH=gfx942;gfx950

ARG PYTHON_VERSION=3.12

@@ -85,30 +85,30 @@ RUN cd /opt/rocm/share/amd_smi \
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install

FROM base AS build_pytorch
# ARG PYTORCH_BRANCH
# ARG PYTORCH_VISION_BRANCH
# ARG PYTORCH_REPO
# ARG PYTORCH_VISION_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
# RUN git clone ${PYTORCH_REPO} pytorch
# RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
# pip install -r requirements.txt && git submodule update --init --recursive \
# && python3 tools/amd_build/build_amd.py \
# && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
# && pip install dist/*.whl
# RUN git clone ${PYTORCH_VISION_REPO} vision
# RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
# && python3 setup.py bdist_wheel --dist-dir=dist \
# && pip install dist/*.whl
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
&& python3 tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${FA_REPO}
RUN cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
# RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
# && cp /app/vision/dist/*.whl /app/install \
# && cp /app/flash-attention/dist/*.whl /app/install
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install

FROM base AS build_aiter
@@ -132,8 +132,8 @@ RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
# RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
# cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
cp /install/*.whl /app/debs

@@ -147,8 +147,8 @@ RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
pip install /install/*.whl
# RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
# pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
pip install /install/*.whl

@@ -172,10 +172,10 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
# && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
# && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
# && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
# && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
# && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
82 changes: 25 additions & 57 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -561,7 +561,8 @@ def __init__(

self.use_naive_attn = False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN and \
envs.VLLM_ROCM_USE_AITER
if self.use_triton_flash_attn:
if logits_soft_cap is not None:
raise ValueError(
@@ -571,16 +572,11 @@ def __init__(
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")

if self.kv_cache_dtype in ["int8", "fp8_e4m3"]:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.triton_attn_func = triton_attention
else:
from aiter.ops.triton.mha import flash_attn_varlen_func
from aiter.ops.triton.mha import (
mha_set_use_int64_strides as set_triton_fa_strides)
set_triton_fa_strides(True)
self.triton_attn_func = flash_attn_varlen_func
from aiter.ops.triton.mha import flash_attn_varlen_func
from aiter.ops.triton.mha import (mha_set_use_int64_strides as
set_triton_fa_strides)
set_triton_fa_strides(True)
self.triton_attn_func = flash_attn_varlen_func
logger.debug("Using Triton FA in ROCmBackend")
if self.sliding_window != (-1, -1):
logger.warning("ROCm Triton FA does not currently support "
@@ -803,52 +799,24 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=causal_mask) # type: ignore
if self.kv_cache_dtype in ["int8", "fp8_e4m3"]:
use_fp8_scales = (layer._q_scale is not None
and layer._k_scale is not None
and layer._v_scale is not None
and layer._prob_scale is not None and
envs.VLLM_USE_ROCM_FP8_FLASH_ATTN)
full_scales = (layer._q_scale.item(),
layer._k_scale.item(),
layer._v_scale.item(),
layer._prob_scale.item()
) if use_fp8_scales else None
self.triton_attn_func(
query,
key,
value,
output[:num_prefill_tokens],
query_seq_start_loc,
key_seq_start_loc,
query_max_seq_len,
key_max_seq_len,
causal_mask,
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
full_scales,
layer._out_scale,
)
else:
output[:num_prefill_tokens] = self.triton_attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=query_max_seq_len,
max_seqlen_k=key_max_seq_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=causal_mask,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
deterministic=False,
return_lse=False,
return_attn_probs=False,
block_table=None,
)
output[:num_prefill_tokens] = self.triton_attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=query_seq_start_loc,
cu_seqlens_k=key_seq_start_loc,
max_seqlen_q=query_max_seq_len,
max_seqlen_k=key_max_seq_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=causal_mask,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
deterministic=False,
return_lse=False,
return_attn_probs=False,
block_table=None,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
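
With this change the prefill path always calls aiter's flash_attn_varlen_func rather than the int8/fp8 triton_attention path. As a rough illustration of the calling convention it relies on, here is a minimal sketch of how cu_seqlens_q / cu_seqlens_k style offsets are typically built from per-sequence lengths, assuming the usual varlen flash-attention convention (a (batch + 1,) int32 tensor of cumulative start offsets); the names below are illustrative and not taken from this PR:

```python
import torch

# Hypothetical prompt lengths for a prefill batch of three sequences.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Varlen flash-attention kernels take cumulative start offsets of shape
# (batch + 1,) beginning at 0, plus the maximum sequence length.
cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0, dtype=torch.int32)

print(cu_seqlens)           # tensor([ 0,  5,  8, 15], dtype=torch.int32)
print(int(seq_lens.max()))  # 7, the max_seqlen_q / max_seqlen_k argument
```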