From 7d200aac5576f973ad456a93be011648af3132a4 Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Tue, 25 Jun 2024 17:56:15 -0500 Subject: [PATCH] [Hardware][AMD][CI/Build][Doc] Upgrade to ROCm 6.1, Dockerfile improvements, test fixes (#5422) --- CMakeLists.txt | 20 +- Dockerfile.rocm | 209 ++++++++++++------ cmake/utils.cmake | 20 +- .../getting_started/amd-installation.rst | 6 +- tests/async_engine/test_openapi_server_ray.py | 4 +- tests/distributed/test_utils.py | 17 +- tests/entrypoints/test_openai_embedding.py | 4 +- tests/entrypoints/test_openai_server.py | 4 +- tests/entrypoints/test_openai_vision.py | 4 +- tests/utils.py | 38 +++- vllm/config.py | 10 +- .../custom_all_reduce_utils.py | 11 +- vllm/executor/multiproc_gpu_executor.py | 8 +- vllm/utils.py | 16 +- vllm/worker/worker_base.py | 10 +- 15 files changed, 259 insertions(+), 122 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index aa15b632cdd3b..801429096eaab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,8 +32,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # versions are derived from Dockerfile.rocm # set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0") -set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") -set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") +set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0") # # Try to find python package with an executable that exactly matches @@ -98,18 +97,11 @@ elseif(HIP_FOUND) # .hip extension automatically, HIP must be enabled explicitly. enable_language(HIP) - # ROCm 5.x - if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} " - "expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.") - endif() - - # ROCm 6.x - if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND - NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} " - "expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.") + # ROCm 5.X and 6.X + if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND + NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) + message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} " + "expected for ROCm build, saw ${Torch_VERSION} instead.") endif() else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 6bda696859c8b..652f04adf8959 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,34 +1,35 @@ -# default base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - -FROM $BASE_IMAGE - -ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - -RUN echo "Base image is $BASE_IMAGE" - -ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ - ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - +# Default ROCm 6.1 base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" + +# Tested and supported base rocm/pytorch images +ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \ + ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \ + ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" + +# Default ROCm ARCHes to build vLLM for. 
+ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" + +# Whether to build CK-based flash-attention +# If 0, will not build flash attention +# This is useful for gfx target where flash-attention is not supported +# (i.e. those that do not appear in `FA_GFX_ARCHS`) +# Triton FA is used by default on ROCm now so this is unnecessary. +ARG BUILD_FA="1" ARG FA_GFX_ARCHS="gfx90a;gfx942" -RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" - ARG FA_BRANCH="ae7928c" -RUN echo "FA_BRANCH is $FA_BRANCH" -# whether to build flash-attention -# if 0, will not build flash attention -# this is useful for gfx target where flash-attention is not supported -# In that case, we need to use the python reference attention implementation in vllm -ARG BUILD_FA="1" - -# whether to build triton on rocm +# Whether to build triton on rocm ARG BUILD_TRITON="1" +ARG TRITON_BRANCH="0ef1848" -# Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +### Base image build stage +FROM $BASE_IMAGE AS base + +# Import arg(s) defined before this build stage +ARG PYTORCH_ROCM_ARCH # Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y RUN apt-get update && apt-get install -y \ curl \ ca-certificates \ @@ -39,79 +40,159 @@ RUN apt-get update && apt-get install -y \ build-essential \ wget \ unzip \ - nvidia-cuda-toolkit \ tmux \ ccache \ && rm -rf /var/lib/apt/lists/* -### Mount Point ### -# When launching the container, mount the code directory to /app +# When launching the container, mount the code directory to /vllm-workspace ARG APP_MOUNT=/vllm-workspace -VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT} -RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas +RUN pip install --upgrade pip +# Remove sccache so it doesn't interfere with ccache +# TODO: implement sccache support across components +RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" +# Install torch == 2.4.0 on ROCm +RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-5.7"*) \ + pip uninstall -y torch \ + && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \ + --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \ + *"rocm-6.0"*) \ + pip uninstall -y torch \ + && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \ + --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \ + *"rocm-6.1"*) \ + pip uninstall -y torch \ + && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \ + --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ + *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: -# Install ROCm flash-attention -RUN if [ "$BUILD_FA" = "1" ]; then \ - mkdir libs \ +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} +ENV CCACHE_DIR=/root/.cache/ccache + + +### AMD-SMI build stage +FROM base AS build_amdsmi +# Build amdsmi wheel always +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . 
--wheel-dir=/install + + +### Flash-Attention wheel build stage +FROM base AS build_fa +ARG BUILD_FA +ARG FA_GFX_ARCHS +ARG FA_BRANCH +# Build ROCm flash-attention wheel if `BUILD_FA = 1` +RUN --mount=type=cache,target=${CCACHE_DIR} \ + if [ "$BUILD_FA" = "1" ]; then \ + mkdir -p libs \ && cd libs \ && git clone https://github.com/ROCm/flash-attention.git \ && cd flash-attention \ - && git checkout ${FA_BRANCH} \ + && git checkout "${FA_BRANCH}" \ && git submodule update --init \ - && export GPU_ARCHS=${FA_GFX_ARCHS} \ - && if [ "$BASE_IMAGE" = "$ROCm_5_7_BASE" ]; then \ - patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ - && python3 setup.py install \ - && cd ..; \ + && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-5.7"*) \ + export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \ + && patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \ + *) ;; esac \ + && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ + # Create an empty directory otherwise as later build stages expect one + else mkdir -p /install; \ fi -# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. -# Manually removed it so that later steps of numpy upgrade can continue -RUN if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \ - rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi -# build triton -RUN if [ "$BUILD_TRITON" = "1" ]; then \ +### Triton wheel build stage +FROM base AS build_triton +ARG BUILD_TRITON +ARG TRITON_BRANCH +# Build triton wheel if `BUILD_TRITON = 1` +RUN --mount=type=cache,target=${CCACHE_DIR} \ + if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ - && pip uninstall -y triton \ - && git clone https://github.com/ROCm/triton.git \ - && cd triton/python \ - && pip3 install . \ - && cd ../..; \ + && git clone https://github.com/OpenAI/triton.git \ + && cd triton \ + && git checkout "${TRITON_BRANCH}" \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=/install; \ + # Create an empty directory otherwise as later build stages expect one + else mkdir -p /install; \ fi -WORKDIR /vllm-workspace + +### Final vLLM build stage +FROM base AS final +# Import the vLLM development directory from the build context COPY . . -#RUN python3 -m pip install pynvml # to be removed eventually -RUN python3 -m pip install --upgrade pip numba +# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. 
+# Manually remove it so that later steps of numpy upgrade can continue +RUN case "$(which python3)" in \ + *"/opt/conda/envs/py_3.9"*) \ + rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \ + *) ;; esac + +# Package upgrades for useful functionality or to avoid dependency issues +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --upgrade numba scipy huggingface-hub[cli] -# make sure punica kernels are built (for LoRA) +# Make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 # Workaround for ray >= 2.10.0 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 +# Silences the HF Tokenizers warning +ENV TOKENIZERS_PARALLELISM=false -ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so - -ENV CCACHE_DIR=/root/.cache/ccache -RUN --mount=type=cache,target=/root/.cache/ccache \ +RUN --mount=type=cache,target=${CCACHE_DIR} \ --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ - && if [ "$BASE_IMAGE" = "$ROCm_6_0_BASE" ]; then \ - patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch; fi \ - && python3 setup.py install \ - && export VLLM_PYTHON_VERSION=$(python -c "import sys; print(str(sys.version_info.major) + str(sys.version_info.minor))") \ - && cp build/lib.linux-x86_64-cpython-${VLLM_PYTHON_VERSION}/vllm/*.so vllm/ \ - && cd .. + && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-6.0"*) \ + patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \ + *"rocm-6.1"*) \ + # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM + wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \ + && cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \ + # Prevent interference if torch bundles its own HIP runtime + && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ + *) ;; esac \ + && python3 setup.py clean --all \ + && python3 setup.py develop + +# Copy amdsmi wheel into final image +RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ + mkdir -p libs \ + && cp /install/*.whl libs \ + # Preemptively uninstall to avoid same-version no-installs + && pip uninstall -y amdsmi; +# Copy triton wheel(s) into final image if they were built +RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ + mkdir -p libs \ + && if ls /install/*.whl; then \ + cp /install/*.whl libs \ + # Preemptively uninstall to avoid same-version no-installs + && pip uninstall -y triton; fi + +# Copy flash-attn wheel(s) into final image if they were built +RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ + mkdir -p libs \ + && if ls /install/*.whl; then \ + cp /install/*.whl libs \ + # Preemptively uninstall to avoid same-version no-installs + && pip uninstall -y flash-attn; fi + +# Install wheels that were built to the final image +RUN --mount=type=cache,target=/root/.cache/pip \ + if ls libs/*.whl; then \ + pip install libs/*.whl; fi CMD ["/bin/bash"] diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 071e16336dfa2..4869cad541135 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -147,19 +147,23 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) if (${GPU_LANG} STREQUAL "HIP") # # `GPU_ARCHES` controls the `--offload-arch` flags. - # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled - # via the `PYTORCH_ROCM_ARCH` env variable. 
# - + # If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list, + # if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling + # "rocm_agent_enumerator" in "enable_language(HIP)" + # (in file Modules/CMakeDetermineHIPCompiler.cmake) + # + if(DEFINED ENV{PYTORCH_ROCM_ARCH}) + set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH}) + else() + set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES}) + endif() # # Find the intersection of the supported + detected architectures to # set the module architecture flags. # - - set(VLLM_ROCM_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") - set(${GPU_ARCHES}) - foreach (_ARCH ${VLLM_ROCM_SUPPORTED_ARCHS}) + foreach (_ARCH ${HIP_ARCHITECTURES}) if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST) list(APPEND ${GPU_ARCHES} ${_ARCH}) endif() @@ -167,7 +171,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) if(NOT ${GPU_ARCHES}) message(FATAL_ERROR - "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is" + "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") endif() diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 61fcd45a26347..cc41d47296f8d 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -88,7 +88,7 @@ Option 2: Build from source - `Pytorch `_ - `hipBLAS `_ -For installing PyTorch, you can start from a fresh docker image, e.g, `rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`. +For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`, `rocm/pytorch-nightly`. Alternatively, you can install pytorch using pytorch wheels. You can check Pytorch installation guild in Pytorch `Getting Started `_ @@ -126,12 +126,12 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl $ cd vllm $ pip install -U -r requirements-rocm.txt - $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation + $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation .. tip:: - You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation. - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. - - To use CK flash-attention, please use this flag ``export VLLM_USE_FLASH_ATTN_TRITON=0`` to turn off triton flash attention. + - To use CK flash-attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. - The ROCm version of pytorch, ideally, should match the ROCm driver version. diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index cc05d79e56874..332937b874e93 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -4,7 +4,7 @@ # and debugging. 
import ray -from ..utils import VLLM_PATH, RemoteOpenAIServer +from ..utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "facebook/opt-125m" @@ -12,7 +12,7 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init(runtime_env={"working_dir": VLLM_PATH}) + ray.init() yield ray.shutdown() diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 49d11daca9aec..9ff11b0d27b11 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -1,8 +1,8 @@ -import os - import ray -from vllm.utils import cuda_device_count_stateless +import vllm.envs as envs +from vllm.utils import (cuda_device_count_stateless, is_hip, + update_environment_variables) @ray.remote @@ -12,16 +12,21 @@ def get_count(self): return cuda_device_count_stateless() def set_cuda_visible_devices(self, cuda_visible_devices: str): - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) def get_cuda_visible_devices(self): - return os.environ["CUDA_VISIBLE_DEVICES"] + return envs.CUDA_VISIBLE_DEVICES def test_cuda_device_count_stateless(): """Test that cuda_device_count_stateless changes return value if CUDA_VISIBLE_DEVICES is changed.""" - + if is_hip(): + # Set HIP_VISIBLE_DEVICES == CUDA_VISIBLE_DEVICES. Conversion + # is handled by `update_environment_variables` + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES}) actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore num_gpus=2).remote() assert sorted(ray.get( diff --git a/tests/entrypoints/test_openai_embedding.py b/tests/entrypoints/test_openai_embedding.py index 2496d2ac3e97d..45f701733df0c 100644 --- a/tests/entrypoints/test_openai_embedding.py +++ b/tests/entrypoints/test_openai_embedding.py @@ -2,7 +2,7 @@ import pytest import ray -from ..utils import VLLM_PATH, RemoteOpenAIServer +from ..utils import RemoteOpenAIServer EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -11,7 +11,7 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init(runtime_env={"working_dir": VLLM_PATH}) + ray.init() yield ray.shutdown() diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index c22a675ff1230..5196d81815502 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -16,7 +16,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer -from ..utils import VLLM_PATH, RemoteOpenAIServer +from ..utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -81,7 +81,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init(runtime_env={"working_dir": VLLM_PATH}) + ray.init() yield ray.shutdown() diff --git a/tests/entrypoints/test_openai_vision.py b/tests/entrypoints/test_openai_vision.py index 03dc5d1161f0e..0e8d88b76ffec 100644 --- a/tests/entrypoints/test_openai_vision.py +++ b/tests/entrypoints/test_openai_vision.py @@ -8,7 +8,7 @@ from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64 -from ..utils import VLLM_PATH, RemoteOpenAIServer +from ..utils import RemoteOpenAIServer MODEL_NAME = "llava-hf/llava-1.5-7b-hf" LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent / @@ -27,7 +27,7 @@ @pytest.fixture(scope="module") def ray_ctx(): - ray.init(runtime_env={"working_dir": VLLM_PATH}) + ray.init() yield ray.shutdown() diff --git 
a/tests/utils.py b/tests/utils.py index 174efca4af532..2a5f82b91c42c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,9 +15,30 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils import get_open_port, is_hip -if (not is_hip()): +if is_hip(): + from amdsmi import (amdsmi_get_gpu_vram_usage, + amdsmi_get_processor_handles, amdsmi_init, + amdsmi_shut_down) + + @contextmanager + def _nvml(): + try: + amdsmi_init() + yield + finally: + amdsmi_shut_down() +else: from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, - nvmlInit) + nvmlInit, nvmlShutdown) + + @contextmanager + def _nvml(): + try: + nvmlInit() + yield + finally: + nvmlShutdown() + # Path to root of repository so that utilities can be imported by ray workers VLLM_PATH = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir)) @@ -160,20 +181,25 @@ def error_on_warning(): yield +@_nvml() def wait_for_gpu_memory_to_clear(devices: List[int], threshold_bytes: int, timeout_s: float = 120) -> None: # Use nvml instead of pytorch to reduce measurement error from torch cuda # context. - nvmlInit() start_time = time.time() while True: output: Dict[int, str] = {} output_raw: Dict[int, float] = {} for device in devices: - dev_handle = nvmlDeviceGetHandleByIndex(device) - mem_info = nvmlDeviceGetMemoryInfo(dev_handle) - gb_used = mem_info.used / 2**30 + if is_hip(): + dev_handle = amdsmi_get_processor_handles()[device] + mem_info = amdsmi_get_gpu_vram_usage(dev_handle) + gb_used = mem_info["vram_used"] / 2**10 + else: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 output_raw[device] = gb_used output[device] = f'{gb_used:.02f}' diff --git a/vllm/config.py b/vllm/config.py index 0217a2b569928..0c4d770e46847 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,13 +7,15 @@ import torch from transformers import PretrainedConfig, PreTrainedTokenizerBase +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.models import ModelRegistry from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_tpu, is_xpu) + is_hip, is_neuron, is_tpu, is_xpu, + update_environment_variables) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -634,6 +636,12 @@ def __init__( self.distributed_executor_backend = backend logger.info("Defaulting to use %s for distributed inference", backend) + # If CUDA_VISIBLE_DEVICES is set on ROCm prior to vLLM init, + # propagate changes to HIP_VISIBLE_DEVICES (conversion handled by + # the update_environment_variables function) + if is_hip() and envs.CUDA_VISIBLE_DEVICES: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES}) self._verify_args() diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index d3e41fa710676..6f1aaed9881a2 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -13,7 +13,8 @@ import vllm.envs as envs from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.logger import init_logger -from vllm.utils import cuda_device_count_stateless +from 
vllm.utils import (cuda_device_count_stateless, + update_environment_variables) logger = init_logger(__name__) @@ -24,7 +25,8 @@ def producer(batch_src: Sequence[int], result_queue, cuda_visible_devices: Optional[str] = None): if cuda_visible_devices is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for i in batch_src: @@ -56,7 +58,8 @@ def consumer(batch_tgt: Sequence[int], result_queue, cuda_visible_devices: Optional[str] = None): if cuda_visible_devices is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) lib = CudaRTLibrary() for j in batch_tgt: @@ -123,7 +126,7 @@ def can_actually_p2p( processes for testing all pairs of GPUs in batch. The trick is to reset the device after each test (which is not available in PyTorch). """ # noqa - cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None) + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES # pass the CUDA_VISIBLE_DEVICES to the child process # to make sure they see the same set of GPUs diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index e63e5a3a027fa..a5b1d27f27596 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -11,7 +11,8 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (cuda_device_count_stateless, get_distributed_init_method, get_open_port, - get_vllm_instance_id, make_async) + get_vllm_instance_id, make_async, + update_environment_variables) logger = init_logger(__name__) @@ -25,8 +26,9 @@ def _init_executor(self) -> None: # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers if "CUDA_VISIBLE_DEVICES" not in os.environ: - os.environ["CUDA_VISIBLE_DEVICES"] = (",".join( - map(str, range(world_size)))) + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() diff --git a/vllm/utils.py b/vllm/utils.py index f0c7df5cf8c22..92abdb3fb9b14 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -376,6 +376,10 @@ def get_open_port() -> int: def update_environment_variables(envs: Dict[str, str]): + if is_hip() and "CUDA_VISIBLE_DEVICES" in envs: + # Propagate changes to CUDA_VISIBLE_DEVICES to + # ROCm's HIP_VISIBLE_DEVICES as well + envs["HIP_VISIBLE_DEVICES"] = envs["CUDA_VISIBLE_DEVICES"] for k, v in envs.items(): if k in os.environ and os.environ[k] != v: logger.warning( @@ -779,9 +783,14 @@ def _cuda_device_count_stateless( if not torch.cuda._is_compiled(): return 0 - # bypass _device_count_nvml() if rocm (not supported) - nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml() - r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count + if is_hip(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = torch.cuda._device_count_amdsmi() if (hasattr( + torch.cuda, "_device_count_amdsmi")) else -1 + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count return r @@ -795,7 +804,6 @@ def cuda_device_count_stateless() -> int: # This can be removed and simply replaced with torch.cuda.get_device_count # after 
https://github.com/pytorch/pytorch/pull/122815 is released. - return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index dc09718de4a32..99482aa93bc59 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -6,7 +6,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (enable_trace_function_call_for_thread, +from vllm.utils import (enable_trace_function_call_for_thread, is_hip, update_environment_variables) logger = init_logger(__name__) @@ -125,6 +125,14 @@ def update_environment_variables(envs: Dict[str, str]) -> None: # overwriting CUDA_VISIBLE_DEVICES is desired behavior # suppress the warning in `update_environment_variables` del os.environ[key] + if is_hip(): + hip_env_var = "HIP_VISIBLE_DEVICES" + if hip_env_var in os.environ: + logger.warning( + "Ignoring pre-set environment variable `%s=%s` as " + "%s has also been set, which takes precedence.", + hip_env_var, os.environ[hip_env_var], key) + os.environ.pop(hip_env_var, None) update_environment_variables(envs) def init_worker(self, *args, **kwargs):
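
For reference, the build arguments introduced in Dockerfile.rocm (BASE_IMAGE, PYTORCH_ROCM_ARCH, BUILD_FA, FA_GFX_ARCHS, FA_BRANCH, BUILD_TRITON, TRITON_BRANCH) can be overridden with `--build-arg` at image build time. A minimal sketch of such an invocation from a checked-out vLLM tree is shown below; the image tag and the narrowed PYTORCH_ROCM_ARCH list are illustrative only, while the other values match the defaults declared above:

    $ docker build -f Dockerfile.rocm \
        --build-arg BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" \
        --build-arg BUILD_FA="1" \
        --build-arg FA_GFX_ARCHS="gfx90a;gfx942" \
        --build-arg BUILD_TRITON="1" \
        --build-arg PYTORCH_ROCM_ARCH="gfx90a;gfx942" \
        -t vllm-rocm .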