From 62a83fd80045d13bd291370baef4810e6345f9d7 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 26 Apr 2024 21:48:58 +0800 Subject: [PATCH] add intel xpu support for TGI (#1475) Fixes # (issue) - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. --------- Signed-off-by: Wang, Yi A Co-authored-by: Morgan Funtowicz Co-authored-by: Nicolas Patry --- .github/workflows/build.yaml | 93 ++++++++++++ Dockerfile_intel | 105 +++++++++++++ launcher/src/env_runtime.rs | 13 +- .../models/cache_manager.py | 6 +- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../custom_modeling/flash_mixtral_modeling.py | 5 +- .../models/flash_causal_lm.py | 28 +++- .../models/flash_llama.py | 4 + .../models/flash_mistral.py | 6 +- .../models/flash_neox.py | 5 +- .../text_generation_server/models/flash_rw.py | 5 +- .../models/flash_santacoder.py | 4 + .../text_generation_server/models/globals.py | 4 +- server/text_generation_server/utils/dist.py | 10 +- .../utils/flash_attn.py | 140 +++++++++++------- .../utils/import_utils.py | 9 ++ server/text_generation_server/utils/layers.py | 39 ++++- .../utils/paged_attention.py | 30 +++- 18 files changed, 434 insertions(+), 76 deletions(-) create mode 100644 Dockerfile_intel diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 066ea889f61..f1131450bf3 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -274,12 +274,105 @@ jobs: cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min + build-and-push-image-intel: + concurrency: + group: ${{ github.workflow }}-build-and-push-image-intel-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + needs: + - start-runner + - build-and-push-image # Wait for the main docker image to be built + - integration-tests # Wait for the main integration-tests + runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner + permissions: + contents: write + packages: write + # This is used to complete the identity challenge + # with sigstore/fulcio when running outside of PRs. + id-token: write + security-events: write + steps: + - name: Checkout repository + uses: actions/checkout@v3 + - name: Initialize Docker Buildx + uses: docker/setup-buildx-action@v2.0.0 + with: + install: true + - name: Inject slug/short variables + uses: rlespinasse/github-slug-action@v4.4.1 + - name: Tailscale + uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966 + with: + authkey: ${{ secrets.TAILSCALE_AUTHKEY }} + - name: Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Login to internal Container Registry + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.TAILSCALE_DOCKER_USERNAME }} + password: ${{ secrets.TAILSCALE_DOCKER_PASSWORD }} + registry: registry.internal.huggingface.tech + - name: Login to Azure Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v2.1.0 + with: + username: ${{ secrets.AZURE_DOCKER_USERNAME }} + password: ${{ secrets.AZURE_DOCKER_PASSWORD }} + registry: db4c2190dd824d1f950f5d1555fbadf0.azurecr.io + # If pull request + - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name == 'pull_request' }} + id: meta-pr + uses: docker/metadata-action@v4.3.0 + with: + images: | + registry.internal.huggingface.tech/api-inference/community/text-generation-inference + tags: | + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel + # If main, release or tag + - name: Extract metadata (tags, labels) for Docker + if: ${{ github.event_name != 'pull_request' }} + id: meta + uses: docker/metadata-action@v4.3.0 + with: + flavor: | + latest=false + images: | + registry.internal.huggingface.tech/api-inference/community/text-generation-inference + ghcr.io/huggingface/text-generation-inference + db4c2190dd824d1f950f5d1555fbadf0.azurecr.io/text-generation-inference + tags: | + type=semver,pattern={{version}}-intel + type=semver,pattern={{major}}.{{minor}}-intel + type=raw,value=latest-intel,enable=${{ github.ref == format('refs/heads/{0}', github.event.repository.default_branch) }} + type=raw,value=sha-${{ env.GITHUB_SHA_SHORT }}-intel + - name: Build and push Docker image + id: build-and-push + uses: docker/build-push-action@v4 + with: + context: . + file: Dockerfile_intel + push: true + platforms: 'linux/amd64' + build-args: | + GIT_SHA=${{ env.GITHUB_SHA }} + DOCKER_LABEL=sha-${{ env.GITHUB_SHA_SHORT }}-intel + tags: ${{ steps.meta.outputs.tags || steps.meta-pr.outputs.tags }} + labels: ${{ steps.meta.outputs.labels || steps.meta-pr.outputs.labels }} + cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min + cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-intel,mode=min + stop-runner: name: Stop self-hosted EC2 runner needs: - start-runner - build-and-push-image - build-and-push-image-rocm + - build-and-push-image-intel - integration-tests runs-on: ubuntu-latest env: diff --git a/Dockerfile_intel b/Dockerfile_intel new file mode 100644 index 00000000000..d0791cac12b --- /dev/null +++ b/Dockerfile_intel @@ -0,0 +1,105 @@ +FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef +WORKDIR /usr/src + +ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse + +FROM chef as planner +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo chef prepare --recipe-path recipe.json + +FROM chef AS builder + +ARG GIT_SHA +ARG DOCKER_LABEL + +RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ + curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ + unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ + unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \ + rm -f $PROTOC_ZIP + +COPY --from=planner /usr/src/recipe.json recipe.json +RUN cargo chef cook --release --recipe-path recipe.json + +COPY Cargo.toml Cargo.toml +COPY rust-toolchain.toml rust-toolchain.toml +COPY proto proto +COPY benchmark benchmark +COPY router router +COPY launcher launcher +RUN cargo build --release + + +# Text Generation Inference base image for Intel +FROM intel/intel-extension-for-pytorch:2.1.10-xpu as base + +USER root +# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it +RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \ + dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb + + +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \ +| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list + +RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build + +# Text Generation Inference base env +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + + +WORKDIR /usr/src +# Build pytorch and ipex +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b xpu_main origin/xpu-main +RUN git clone https://github.com/pytorch/pytorch.git && cd pytorch && git checkout 209f2fa8ff86652f67d75c2f19bf9cb9942fd018 && git apply /usr/src/intel-extension-for-pytorch/torch_patches/00*.patch + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_cuda.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest +ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest +ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric +ENV DIAGUTIL_PATH=/opt/intel/oneapi/compiler/latest/etc/compiler/sys_check/sys_check.sh +ENV CCL_CONFIGURATION=cpu_gpu_dpcpp +ENV MANPATH=/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/mpi/latest/share/man:/opt/intel/oneapi/compiler/latest/share/man +ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest +ENV CMPLR_ROOT=/opt/intel/oneapi/compiler/latest +ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib +ENV OCL_ICD_FILENAMES=libintelocl_emu.so:libalteracl.so:/opt/intel/oneapi/compiler/latest/lib/libintelocl.so +ENV CLASSPATH=/opt/intel/oneapi/mpi/latest/share/java/mpi.jar:/opt/intel/oneapi/mpi/latest/share/java/mpi.jar +ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64: +ENV MKLROOT=/opt/intel/oneapi/mkl/latest +ENV NLSPATH=/opt/intel/oneapi/mkl/latest/share/locale/%l_%t/%N:/opt/intel/oneapi/compiler/latest/lib/locale/%l_%t/%N +ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include +ENV CCL_ZE_IPC_EXCHANGE=sockets + + +RUN pip uninstall -y torch && cd pytorch && git submodule update --init --recursive && python setup.py install +RUN pip uninstall -y intel-extension-for-pytorch && cd intel-extension-for-pytorch && git submodule update --init --recursive && USE_AOT_DEVLIST='pvc' BUILD_SEPARATE_OPS=ON BUILD_WITH_CPU=ON USE_XETLA=ON python setup.py install + +# Install benchmarker +COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher + +# Final image +FROM base + +ENTRYPOINT ["text-generation-launcher"] +CMD ["--json-output"] diff --git a/launcher/src/env_runtime.rs b/launcher/src/env_runtime.rs index 9dbc83f7907..08fb301c3f4 100644 --- a/launcher/src/env_runtime.rs +++ b/launcher/src/env_runtime.rs @@ -7,14 +7,17 @@ pub(crate) struct Env { git_sha: &'static str, docker_label: &'static str, nvidia_env: String, + xpu_env: String, } impl Env { pub fn new() -> Self { let nvidia_env = nvidia_smi(); + let xpu_env = xpu_smi(); Self { nvidia_env: nvidia_env.unwrap_or("N/A".to_string()), + xpu_env: xpu_env.unwrap_or("N/A".to_string()), cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"), cargo_version: env!("VERGEN_RUSTC_SEMVER"), git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), @@ -31,7 +34,8 @@ impl fmt::Display for Env { writeln!(f, "Cargo version: {}", self.cargo_version)?; writeln!(f, "Commit sha: {}", self.git_sha)?; writeln!(f, "Docker label: {}", self.docker_label)?; - write!(f, "nvidia-smi:\n{}", self.nvidia_env)?; + writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?; + write!(f, "xpu-smi:\n{}", self.xpu_env)?; Ok(()) } @@ -43,3 +47,10 @@ fn nvidia_smi() -> Option { let output = nvidia_smi.replace('\n', "\n "); Some(output.trim().to_string()) } + +fn xpu_smi() -> Option { + let output = Command::new("xpu-smi").arg("discovery").output().ok()?; + let xpu_smi = String::from_utf8(output.stdout).ok()?; + let output = xpu_smi.replace('\n', "\n "); + Some(output.trim().to_string()) +} diff --git a/server/text_generation_server/models/cache_manager.py b/server/text_generation_server/models/cache_manager.py index 85e1b19bb20..4c65e2dd711 100644 --- a/server/text_generation_server/models/cache_manager.py +++ b/server/text_generation_server/models/cache_manager.py @@ -2,6 +2,7 @@ import torch from typing import Optional, List, Tuple +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM BLOCK_SIZE: int = 16 # Will be set in warmup @@ -24,7 +25,10 @@ def __init__( self.repeat_slots = repeat_slots element_size = torch.tensor([], dtype=dtype).element_size() - x = self.block_size // element_size + if IS_XPU_SYSTEM: + x = 1 + else: + x = self.block_size // element_size self.kv_cache = [ ( diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index d04ce39ed8b..d0978bef61f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -21,8 +21,10 @@ from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any from loguru import logger +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM -from vllm.model_executor.layers.fused_moe import fused_moe +if not IS_XPU_SYSTEM: + from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( FastLinear, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index be8cb965025..3f6c8e036a3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,7 +24,10 @@ import numpy as np from torch import nn -from vllm.model_executor.layers.fused_moe import fused_moe +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM + +if not IS_XPU_SYSTEM: + from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 1189ccdd1b8..94518b8ff86 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -33,7 +33,7 @@ from text_generation_server.utils.dist import MEMORY_FRACTION tracer = trace.get_tracer(__name__) - +from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM @dataclass class FlashCausalLMBatch(Batch): @@ -752,7 +752,10 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): def warmup(self, batch: FlashCausalLMBatch): # The warmup batch is the biggest batch we could ever receive - torch.cuda.empty_cache() + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + torch.cuda.empty_cache() + elif IS_XPU_SYSTEM: + torch.xpu.empty_cache() try: cache_manager = set_cache_manager( batch.blocks, @@ -772,7 +775,10 @@ def warmup(self, batch: FlashCausalLMBatch): f"You need to decrease `--max-batch-prefill-tokens`" ) from e - torch.cuda.synchronize(self.device) + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + torch.cuda.synchronize(self.device) + elif IS_XPU_SYSTEM: + torch.xpu.synchronize(self.device) # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm) # Calculate the number of blocks that can be allocated with the free memory @@ -780,12 +786,18 @@ def warmup(self, batch: FlashCausalLMBatch): cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size - total_free_memory, _ = torch.cuda.mem_get_info(self.device) - total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory + if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + total_free_memory, _ = torch.cuda.mem_get_info(self.device) + total_gpu_memory = torch.cuda.get_device_properties(self.device).total_memory - free_memory = max( - 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory - ) + free_memory = max( + 0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory + ) + elif IS_XPU_SYSTEM: + total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory + free_memory = int(total_gpu_memory *0.5) + else: + raise NotImplementedError("FlashModel is only available on GPU") num_blocks = ( # Leave 5% for some wiggle room diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index f3578f88f2e..f37fc542b7f 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -18,6 +18,7 @@ tracer = trace.get_tracer(__name__) +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM class FlashLlama(FlashCausalLM): def __init__( @@ -33,6 +34,9 @@ def __init__( if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 52a30b5f707..e2ad78d9471 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -33,8 +33,9 @@ # Will be set in init SLIDING_WINDOW: Optional[int] = None SLIDING_WINDOW_BLOCKS: Optional[int] = None +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM -MEM_POOL = torch.cuda.graph_pool_handle() +MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None def set_sliding_window(sliding_window: int, sliding_window_blocks: int): @@ -316,6 +317,9 @@ def __init__( if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 5a351bd7281..70c978def2c 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ weight_files, Weights, ) - +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM tracer = trace.get_tracer(__name__) @@ -32,6 +32,9 @@ def __init__( if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index fc1e26bd2c2..6eb25f22507 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ weight_files, Weights, ) - +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM tracer = trace.get_tracer(__name__) @@ -33,6 +33,9 @@ def __init__( if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 034949f9842..6147398aeff 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,6 +18,7 @@ Weights, ) +from text_generation_server.utils.import_utils import IS_XPU_SYSTEM tracer = trace.get_tracer(__name__) @@ -35,6 +36,9 @@ def __init__( if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IS_XPU_SYSTEM: + device = torch.device(f"xpu:{rank}") + dtype = torch.float16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py index 91b4225aaec..b92aa65bd40 100644 --- a/server/text_generation_server/models/globals.py +++ b/server/text_generation_server/models/globals.py @@ -1,10 +1,10 @@ import torch import os -MEM_POOL = torch.cuda.graph_pool_handle() +MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None # This is overridden by the cli cuda_graphs = os.getenv("CUDA_GRAPHS") -if cuda_graphs is not None and cuda_graphs != "0": +if torch.cuda.is_available() and cuda_graphs is not None and cuda_graphs != "0": try: cuda_graphs = [int(item) for item in cuda_graphs.split(",")] except Exception as e: diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index ad170e44354..d370a3d5cea 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -68,7 +68,15 @@ def initialize_torch_distributed(): if world_size > n_hpus: raise ValueError(f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus}).") else: - backend = "gloo" + try: + import oneccl_bindings_for_pytorch + + backend = "ccl" + if os.getenv("CCL_WORKER_COUNT", None) is None: + os.environ["CCL_WORKER_COUNT"] = str(1) + except ImportError: + backend = "gloo" + options = None if WORLD_SIZE == 1: return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE diff --git a/server/text_generation_server/utils/flash_attn.py b/server/text_generation_server/utils/flash_attn.py index 45090c648fe..583a8f91e4d 100644 --- a/server/text_generation_server/utils/flash_attn.py +++ b/server/text_generation_server/utils/flash_attn.py @@ -2,69 +2,81 @@ import torch from loguru import logger +import math -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import ( + IS_CUDA_SYSTEM, + IS_ROCM_SYSTEM, + IS_XPU_SYSTEM, +) if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": raise ImportError("`USE_FLASH_ATTENTION` is false.") +HAS_FLASH_ATTN = True +HAS_FLASH_ATTN_V2_CUDA = False +HAS_FLASH_ATTN_V2_ROCM = False -if not torch.cuda.is_available(): - raise ImportError("CUDA is not available") +if IS_XPU_SYSTEM: + import intel_extension_for_pytorch as ipex -major, minor = torch.cuda.get_device_capability() -is_sm75 = major == 7 and minor == 5 -is_sm8x = major == 8 and minor >= 0 -is_sm90 = major == 9 and minor == 0 +if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM: + if not torch.cuda.is_available(): + raise ImportError("CUDA is not available") -HAS_FLASH_ATTN = False -HAS_FLASH_ATTN_V2_CUDA = False -HAS_FLASH_ATTN_V2_ROCM = False -try: + major, minor = torch.cuda.get_device_capability() + is_sm75 = major == 7 and minor == 5 + is_sm8x = major == 8 and minor >= 0 + is_sm90 = major == 9 and minor == 0 + + HAS_FLASH_ATTN = False + HAS_FLASH_ATTN_V2_CUDA = False + HAS_FLASH_ATTN_V2_ROCM = False try: - import flash_attn_2_cuda - except ImportError: - architecture_suffix = "" - if IS_CUDA_SYSTEM: - architecture_suffix = "-cuda" + try: + import flash_attn_2_cuda + except ImportError: + architecture_suffix = "" + if IS_CUDA_SYSTEM: + architecture_suffix = "-cuda" + elif IS_ROCM_SYSTEM: + architecture_suffix = "-rocm" + raise ImportError( + "Flash Attention V2 is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" + ) + if not (is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported for " + "Flash Attention V2" + ) + HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM + HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM + except ImportError as e: + try: + import flash_attn_cuda + except ImportError: + raise ImportError( + "Flash Attention is not installed.\n" + "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " + "or install flash attention with `cd server && make install install-flash-attention`" + ) from e + + if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): + raise ImportError( + f"GPU with CUDA capability {major} {minor} is not supported" + ) from e elif IS_ROCM_SYSTEM: - architecture_suffix = "-rocm" - raise ImportError( - "Flash Attention V2 is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`" - ) - if not (is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported for " - "Flash Attention V2" - ) - HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM - HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM -except ImportError as e: - try: - import flash_attn_cuda - except ImportError: - raise ImportError( - "Flash Attention is not installed.\n" - "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " - "or install flash attention with `cd server && make install install-flash-attention`" - ) from e - - if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90): - raise ImportError( - f"GPU with CUDA capability {major} {minor} is not supported" - ) from e - elif IS_ROCM_SYSTEM: - for idx in range(torch.cuda.device_count()): - if "MI210" not in torch.cuda.get_device_name( - idx - ) and "MI250" not in torch.cuda.get_device_name(idx): - raise ImportError( - f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" - ) + for idx in range(torch.cuda.device_count()): + if "MI210" not in torch.cuda.get_device_name( + idx + ) and "MI250" not in torch.cuda.get_device_name(idx): + raise ImportError( + f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention" + ) - logger.warning(f"Unable to use Flash Attention V2: {e}") - HAS_FLASH_ATTN = True + logger.warning(f"Unable to use Flash Attention V2: {e}") + HAS_FLASH_ATTN = True def attention( @@ -80,6 +92,28 @@ def attention( if window_size_left <= 0 and window_size_left != -1: raise ValueError("`window_size_left` must be > 0 or -1") + if IS_XPU_SYSTEM: + if window_size_left != -1: + raise ValueError( + f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})." + ) + return ipex.llm.functional.varlen_attention( + q, + k, + v, + out, + cu_seqlens, + cu_seqlens, + max_s, + max_s, + 0.0, + softmax_scale, + False, + True, + False, + None, + ) + if HAS_FLASH_ATTN_V2_CUDA: return flash_attn_2_cuda.varlen_fwd( q, diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index 428c9f3efc9..7c0d8001d17 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -1,4 +1,13 @@ import torch +def is_xpu_available(): + try: + import intel_extension_for_pytorch + except ImportError: + return False + + return hasattr(torch, "xpu") and torch.xpu.is_available() + IS_ROCM_SYSTEM = torch.version.hip is not None IS_CUDA_SYSTEM = torch.version.cuda is not None +IS_XPU_SYSTEM = is_xpu_available() diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 69bd5e88a44..638cb0a0c1e 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -18,7 +18,14 @@ from accelerate import init_empty_weights from text_generation_server.utils.gptq.quant_linear import QuantLinear -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import ( + IS_CUDA_SYSTEM, + IS_ROCM_SYSTEM, + IS_XPU_SYSTEM, +) + +if IS_XPU_SYSTEM: + import intel_extension_for_pytorch as ipex HAS_AWQ = True try: @@ -812,7 +819,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: + if IS_XPU_SYSTEM: + res_out = hidden_states + out = ipex.llm.functional.add_layer_norm( + residual, hidden_states, self.weight, self.bias, self.eps, True + ) + if residual is not None: + res_out = residual + return out, res_out + elif hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM: if residual is not None: hidden_states += residual residual = hidden_states @@ -858,7 +873,20 @@ def load(cls, prefix, weights, eps=1e-6): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if hidden_states.shape[-1] > 8192: + if IS_XPU_SYSTEM: + residual_out = hidden_states + out = ipex.llm.functional.add_rms_norm( + residual, + hidden_states, + self.weight, + None, + self.variance_epsilon, + True, + ) + if residual is not None: + residual_out = residual + return out, residual_out + elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual residual = hidden_states @@ -984,6 +1012,10 @@ def forward( # Inplace operation, updating query and key. pos_encoding_ops.rotary_embedding(query, key, head_size, cos, sin, True) + elif IS_XPU_SYSTEM: + ipex.llm.functional.rotary_embedding( + query, key, sin, cos, query.size(-1), True + ) else: raise ValueError( "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." @@ -1103,6 +1135,7 @@ def get_cos_sin( cos = torch.index_select(self._cos_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids) + # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow. return cos.unsqueeze(1), sin.unsqueeze(1) diff --git a/server/text_generation_server/utils/paged_attention.py b/server/text_generation_server/utils/paged_attention.py index 487a3a72e72..62c0c893f6a 100644 --- a/server/text_generation_server/utils/paged_attention.py +++ b/server/text_generation_server/utils/paged_attention.py @@ -1,9 +1,15 @@ import torch - -from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM +from text_generation_server.utils.import_utils import ( + IS_CUDA_SYSTEM, + IS_ROCM_SYSTEM, + IS_XPU_SYSTEM, +) _PARTITION_SIZE = 512 +if IS_XPU_SYSTEM: + import intel_extension_for_pytorch as ipex + def reshape_and_cache( key: torch.Tensor, @@ -22,6 +28,10 @@ def reshape_and_cache( from vllm import cache_ops cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots) + elif IS_XPU_SYSTEM: + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slots + ) else: raise ValueError("vllm is not supported on your system") @@ -58,6 +68,22 @@ def attention( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE + if IS_XPU_SYSTEM: + query = query.contiguous() + return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( + out, + query, + key_cache, + value_cache, + kv_head_mapping, + softmax_scale, + block_tables, + input_lengths, + block_size, + max_s, + None, + ) + # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use # V1 to avoid the overhead of reduction. Also, if the number of