diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh new file mode 100644 index 0000000000000..22a7e76937a76 --- /dev/null +++ b/.buildkite/run-xpu-test.sh @@ -0,0 +1,14 @@ +# This script builds the XPU docker image and runs offline inference inside the container. +# It serves as a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t xpu-test -f Dockerfile.xpu . + +# Setup cleanup +remove_docker_container() { docker rm -f xpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 4a20a462b98ec..3bd1e90c2b711 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -45,6 +45,11 @@ steps: queue: intel command: bash .buildkite/run-cpu-test.sh + - label: "XPU Test" + agents: + queue: intel + command: bash .buildkite/run-xpu-test.sh + {% for step in steps %} - label: "{{ step.label }}" agents: diff --git a/Dockerfile.xpu b/Dockerfile.xpu new file mode 100644 index 0000000000000..c39e551672d20 --- /dev/null +++ b/Dockerfile.xpu @@ -0,0 +1,22 @@ +FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04 + +RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ + echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ + chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ + rm /etc/apt/sources.list.d/intel-graphics.list && \ + wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ + echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ + chmod 644 /usr/share/keyrings/intel-graphics.gpg + +RUN apt-get update -y \ +&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip + +COPY ./ /workspace/vllm + +WORKDIR /workspace/vllm + +RUN pip install -v -r requirements-xpu.txt + +RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install + +CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 9937f8333fb7e..11d1bf7a4c58f 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -191,7 +191,7 @@ def run_to_completion(profile_dir: Optional[str] = None): "--device", type=str, default="cuda", - choices=["cuda", "cpu", "tpu"], + choices=["cuda", "cpu", "tpu", "xpu"], help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument('--block-size', type=int, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 48dfce4287671..ed65002bc7d3c 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -349,7 +349,7 @@ def main(args: argparse.Namespace): "--device", type=str, default="cuda", - choices=["cuda", "cpu", "tpu"], + choices=["cuda", "cpu", "tpu", "xpu"], help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument( "--enable-prefix-caching", diff 
--git a/docs/source/getting_started/xpu-installation.rst b/docs/source/getting_started/xpu-installation.rst new file mode 100644 index 0000000000000..4f0d2da25b8e8 --- /dev/null +++ b/docs/source/getting_started/xpu-installation.rst @@ -0,0 +1,61 @@ +.. _installation_xpu: + +Installation with XPU +======================== + +vLLM initially supports basic model inference and serving on the Intel GPU platform. + +Table of contents: + +#. :ref:`Requirements <xpu_backend_requirements>` +#. :ref:`Quick start using Dockerfile <xpu_backend_quick_start_dockerfile>` +#. :ref:`Build from source <build_xpu_backend_from_source>` + +.. _xpu_backend_requirements: + +Requirements +------------ + +* OS: Linux +* Supported Hardware: Intel Data Center GPU (Intel ARC GPU WIP) +* OneAPI requirements: oneAPI 2024.1 + +.. _xpu_backend_quick_start_dockerfile: + +Quick start using Dockerfile +---------------------------- + +.. code-block:: console + + $ docker build -f Dockerfile.xpu -t vllm-xpu-env --shm-size=4g . + $ docker run -it \ + --rm \ + --network=host \ + --device /dev/dri \ + -v /dev/dri/by-path:/dev/dri/by-path \ + vllm-xpu-env + +.. _build_xpu_backend_from_source: + +Build from source +----------------- + +- First, install the required driver and Intel oneAPI 2024.1. + +- Second, install the Python packages needed to build the vLLM XPU backend: + +.. code-block:: console + + $ pip install --upgrade pip + $ pip install -v -r requirements-xpu.txt + +- Finally, build and install the vLLM XPU backend: + +.. code-block:: console + + $ VLLM_TARGET_DEVICE=xpu python setup.py install + +.. note:: + - FP16 is the default data type in the current XPU backend. The BF16 data + type will be supported in the future. + diff --git a/docs/source/index.rst b/docs/source/index.rst index f5d8627596a70..8795a865c3db6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -66,6 +66,7 @@ Documentation getting_started/cpu-installation getting_started/neuron-installation getting_started/tpu-installation + getting_started/xpu-installation getting_started/quickstart getting_started/debugging getting_started/examples/examples_index diff --git a/requirements-xpu.txt b/requirements-xpu.txt new file mode 100644 index 0000000000000..48d899ec70eda --- /dev/null +++ b/requirements-xpu.txt @@ -0,0 +1,11 @@ +# Common dependencies +-r requirements-common.txt + +setuptools < 70.0.0 # IPEX's torch has a dependency that needs this pin; to be removed. 
+ +torch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl +intel_extension_for_pytorch @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.1.30a0-cp310-cp310-linux_x86_64.whl +oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.1.200%2Bxpu-cp310-cp310-linux_x86_64.whl + +triton @ https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v2.1.0/triton-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl + diff --git a/setup.py b/setup.py index 12a704e08eedb..b2ae6def8cdc6 100644 --- a/setup.py +++ b/setup.py @@ -233,6 +233,10 @@ def _is_cpu() -> bool: return VLLM_TARGET_DEVICE == "cpu" +def _is_xpu() -> bool: + return VLLM_TARGET_DEVICE == "xpu" + + def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() @@ -337,6 +341,8 @@ def get_vllm_version() -> str: version += "+tpu" elif _is_cpu(): version += "+cpu" + elif _is_xpu(): + version += "+xpu" else: raise RuntimeError("Unknown runtime environment") @@ -386,6 +392,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-tpu.txt") elif _is_cpu(): requirements = _read_requirements("requirements-cpu.txt") + elif _is_xpu(): + requirements = _read_requirements("requirements-xpu.txt") else: raise ValueError( "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 2f84b8bde6b57..ab2a67950bfea 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -373,7 +373,8 @@ def reshape_and_cache_flash( kv_cache_dtype) -def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, +def copy_blocks(key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], block_mapping: torch.Tensor) -> None: torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py new file mode 100644 index 0000000000000..1e60e0848673b --- /dev/null +++ b/vllm/_ipex_ops.py @@ -0,0 +1,241 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import intel_extension_for_pytorch as ipex +except ImportError as e: + logger.warning("Import error msg: %s", e.msg) + + +class ipex_ops: + + @staticmethod + def _reshape_activation_tensor( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + num = x.size(0) + d = x.size(1) // 2 + x = x.reshape(num, 2, d) + x1, x2 = torch.chunk(x, chunks=2, dim=1) + x1 = x1.reshape(num, d) + x2 = x2.reshape(num, d) + return x1, x2 + + def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + x1, x2 = ipex_ops._reshape_activation_tensor(x) + ipex.llm.functional.silu_mul(x1, x2, out) + + def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + x1, x2 = ipex_ops._reshape_activation_tensor(x) + ipex.llm.functional.gelu_mul(x1, x2, out, "none") + + def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + x1, x2 = ipex_ops._reshape_activation_tensor(x) + ipex.llm.functional.gelu_mul(x1, x2, out, "tanh") + + def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(torch.nn.functional.gelu(x)) + + def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(torch.nn.functional.gelu(x)) + + def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + 
num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + head_mapping = torch.arange( + 0, + num_kv_heads, + device=query.device, + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(num_queries_per_tokens).flatten() + # todo: ipex will refactor namespace + torch.xpu.paged_attention_v1(out, query.contiguous(), + key_cache.view_as(value_cache), + value_cache, head_mapping, scale, + block_tables, context_lens, block_size, + max_context_len, alibi_slopes) + + def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + kv_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + head_mapping = torch.arange( + 0, + num_kv_heads, + dtype=torch.int32, + device=query.device, + ).view(num_kv_heads, + 1).repeat_interleave(num_queries_per_tokens).flatten() + # todo: ipex will refactor namespace + torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, head_mapping, block_tables, + context_lens, scale, block_size, + max_context_len, alibi_slopes) + + def rotary_embedding( + positions: torch.Tensor, # [batch_size, seq_len] + query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] + is_neox: bool, + ) -> None: + if positions.dim() == 1: + positions = positions.unsqueeze(0) + query = query.unsqueeze(0) + key = key.unsqueeze(0) + + rotary_dim = cos_sin_cache.size(1) + query = query.view(*query.shape[:-1], -1, head_size) + key = key.view(*key.shape[:-1], -1, head_size) + + query_rot = query[..., :rotary_dim] + key_rot = key[..., :rotary_dim] + + cos_sin = cos_sin_cache[positions.long()] + cos, sin = cos_sin.chunk(2, dim=-1) + + if is_neox: + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, + rotary_dim, is_neox, positions) + + def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + if positions.dim() == 1: + positions = positions.unsqueeze(0) + query = query.unsqueeze(0) + key = key.unsqueeze(0) + cos_sin_cache_offsets = 
cos_sin_cache_offsets.view_as(positions) + rotary_dim = cos_sin_cache.size(1) + query = query.view(*query.shape[:-1], -1, head_size) + key = key.view(*key.shape[:-1], -1, head_size) + + query_rot = query[..., :rotary_dim] + key_rot = key[..., :rotary_dim] + + cos_sin = cos_sin_cache[torch.add(positions, + cos_sin_cache_offsets).long()] + cos, sin = cos_sin.chunk(2, dim=-1) + + if is_neox: + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos, + rotary_dim, is_neox, positions) + + def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + tmp = ipex.llm.functional.rms_norm(input, weight, epsilon) + out.copy_(tmp) + + def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, + epsilon, True) + input.copy_(tmp) + + def varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + seqlen_q: torch.Tensor, + seqlen_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + pdropout: float, + softmax_scale: float, + zero_tensors: bool, + is_causal: bool, + return_softmax: bool, + gen_: torch.Generator, + ) -> None: + ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q, + seqlen_k, max_seqlen_q, + max_seqlen_k, pdropout, + softmax_scale, zero_tensors, + is_causal, return_softmax, gen_) + + def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, + ) -> None: + assert kv_cache_dtype == "auto" + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def copy_blocks(key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.xpu.copy_blocks(key_caches, value_caches, block_mapping) + + def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.xpu.swap_blocks(src, dst, block_mapping) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py new file mode 100644 index 0000000000000..f09b24f2a0304 --- /dev/null +++ b/vllm/attention/backends/ipex_attn.py @@ -0,0 +1,355 @@ +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "ipex-attn" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "IpexAttnMetadata": + return IpexAttnMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return 
PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + ) -> None: + assert blocksparse_params is None, ValueError( + "Torch SPDA does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. 
" + "Please use xFormers backend instead.") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: IpexAttnMetadata, # type: ignore + kv_scale: float = 1.0, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert kv_scale == 1.0 + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache is not None: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + kv_scale, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache is None or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention(query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. 
+ # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. + ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + kv_scale, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + kv_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 8b07fb2d768f5..1d56d87ccd119 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_cpu, is_hip, is_tpu +from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu logger = init_logger(__name__) @@ -19,6 +19,7 @@ class _Backend(enum.Enum): TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() PALLAS = enum.auto() + IPEX = enum.auto() @lru_cache(maxsize=None) @@ -58,12 +59,17 @@ def get_attn_backend( ROCmFlashAttentionBackend) return ROCmFlashAttentionBackend elif backend == _Backend.TORCH_SDPA: - # TODO: make XPU backend available here. 
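The `_make_sliding_window_bias` helper added in `ipex_attn.py` above builds an additive mask by taking `torch.log` of a 0/1 band matrix, so allowed positions become 0 and masked positions become -inf. A small pure-PyTorch illustration of that trick (the sequence length and window size here are arbitrary):

    import torch

    seq_len, window_size = 5, 3
    ones = torch.full((1, seq_len, seq_len), fill_value=1.0)
    mask = torch.tril(ones)                              # causal: attend to self and past
    mask = torch.triu(mask, diagonal=-window_size + 1)   # keep only the last `window_size` positions
    bias = torch.log(mask)                               # 1 -> 0.0, 0 -> -inf
    print(bias[0])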
assert is_cpu(), RuntimeError( "Torch SDPA backend is only used for the CPU device.") logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.IPEX: + assert is_xpu(), RuntimeError( + "IPEX attention backend is only used for the XPU device.") + logger.info("Using IPEX attention backend.") + from vllm.attention.backends.ipex_attn import IpexAttnBackend + return IpexAttnBackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") logger.warning("Eager mode is required for the Flashinfer backend. " @@ -107,6 +113,11 @@ def which_attn_to_use( logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA + if is_xpu(): + if selected_backend != _Backend.IPEX: + logger.info("Cannot use %s backend on XPU.", selected_backend) + return _Backend.IPEX + if is_tpu(): if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) diff --git a/vllm/config.py b/vllm/config.py index 552d5033fdb9d..b1a3a82f5a6c0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,7 +12,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_tpu) + is_hip, is_neuron, is_tpu, is_xpu) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -757,6 +757,8 @@ def __init__(self, device: str = "auto") -> None: self.device_type = "tpu" elif is_cpu(): self.device_type = "cpu" + elif is_xpu(): + self.device_type = "xpu" else: # We don't call torch.cuda.is_available() here to # avoid initializing CUDA before workers are forked diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 16c5297af1b53..02b0dcbcb6b24 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -58,7 +58,7 @@ def _split_tensor_dict( # because it contains not only the device type but also the device # index (e.g. "cuda:0"). We only need the device type. # receiving side will set the device index. - device = "cpu" if value.is_cpu else "cuda" + device = value.device.type metadata_list.append( (key, TensorMetadata(device, value.dtype, value.size()))) tensor_list.append(value) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ba53b5c86fa72..9d04f1dc557fd 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -501,11 +501,12 @@ def add_cli_args( 'Enabling this will use the fully sharded layers. 
' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) - parser.add_argument("--device", - type=str, - default=EngineArgs.device, - choices=["auto", "cuda", "neuron", "cpu", "tpu"], - help='Device type for vLLM execution.') + parser.add_argument( + "--device", + type=str, + default=EngineArgs.device, + choices=["auto", "cuda", "neuron", "cpu", "tpu", "xpu"], + help='Device type for vLLM execution.') # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 03b6d03a9fdef..ab312850b9ec2 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -383,6 +383,17 @@ def from_engine_args( "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend is None: + from vllm.executor.xpu_executor import XPUExecutorAsync + executor_class = XPUExecutorAsync + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync + executor_class = RayXPUExecutorAsync + else: + raise RuntimeError( + "Not supported distributed execution model on XPU device.") elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fd64337d4384c..eed9a17e477f3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -347,6 +347,14 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutor + executor_class = RayXPUExecutor + else: + from vllm.executor.xpu_executor import XPUExecutor + executor_class = XPUExecutor elif distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 4704f5f1b1a10..495fddd175dd4 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -3,7 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import get_ip, is_hip +from vllm.utils import get_ip, is_hip, is_xpu from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -71,7 +71,7 @@ def initialize_ray_cluster( "serving.") # Connect to a ray cluster. 
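For orientation, the executor selection added to `AsyncLLMEngine.from_engine_args` above boils down to the following decision. This is a simplified standalone sketch, not the actual vLLM code; the synchronous `LLMEngine` path is analogous, minus the error branch:

    def pick_xpu_executor(distributed_executor_backend, is_async: bool) -> str:
        # Mirrors the XPU branch in AsyncLLMEngine.from_engine_args.
        if distributed_executor_backend is None:
            return "XPUExecutorAsync" if is_async else "XPUExecutor"
        if distributed_executor_backend == "ray":
            # The Ray cluster is initialized first, then the Ray-based executor is used.
            return "RayXPUExecutorAsync" if is_async else "RayXPUExecutor"
        raise RuntimeError("Not supported distributed execution model on XPU device.")

    print(pick_xpu_executor(None, is_async=True))    # XPUExecutorAsync
    print(pick_xpu_executor("ray", is_async=False))  # RayXPUExecutor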
- if is_hip(): + if is_hip() or is_xpu(): ray.init(address=ray_address, ignore_reinit_error=True, num_gpus=parallel_config.world_size) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py new file mode 100644 index 0000000000000..dd7c82289341e --- /dev/null +++ b/vllm/executor/ray_xpu_executor.py @@ -0,0 +1,401 @@ +import asyncio +import os +import pickle +from collections import defaultdict +from itertools import islice, repeat +from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, + Tuple, Union) + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + +# If the env var is set, it uses the Ray's compiled DAG API +# which optimizes the control plane overhead. +# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. +USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) + + +class RayXPUExecutor(DistributedGPUExecutor): + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + ) -> None: + assert device_config.device_type == "xpu" + assert (not speculative_config + ), "Speculative decoding not yet supported for XPU backend" + + self.model_config = model_config + self.cache_config = cache_config + self.load_config = load_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers. + self._init_workers_ray(placement_group) + + # Profile the memory usage and initialize the cache. + self.forward_dag = None + if USE_RAY_COMPILED_DAG: + self.forward_dag = self._compiled_ray_dag() + + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + def _init_executor(self) -> None: + pass + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. 
+ + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - Tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + if self.parallel_config.tensor_parallel_size == 1: + # For single GPU case, we use a ray worker with constrained memory. + num_gpus = self.cache_config.gpu_memory_utilization + else: + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Create the workers. + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote( + worker_module_name="vllm.worker.xpu_worker", + worker_class_name="XPUWorker", + trust_remote_code=self.model_config.trust_remote_code, + ) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name="vllm.worker.xpu_worker", + worker_class_name="XPUWorker", + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + node_workers = defaultdict(list) + node_gpus = defaultdict(list) + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + # TODO: add env var for xpu + + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + def collect_arg_helper_func(**kwargs): + # avoid writing `{"name": value}` manually + return kwargs + + init_worker_all_kwargs = [] + + # Initialize the actual workers inside worker wrapper. 
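The reduction described in the `determine_num_available_blocks` docstring above is an element-wise minimum over the per-worker reports, which guarantees the chosen cache size fits on every worker. A tiny standalone example with made-up block counts:

    # Hypothetical (num_gpu_blocks, num_cpu_blocks) reports from three workers.
    num_blocks = [(7936, 2048), (7680, 2048), (7808, 2048)]

    num_gpu_blocks = min(b[0] for b in num_blocks)
    num_cpu_blocks = min(b[1] for b in num_blocks)
    print(num_gpu_blocks, num_cpu_blocks)  # 7680 2048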
+ for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ): + local_rank = node_workers[node_id].index(rank) + init_worker_all_kwargs.append( + collect_arg_helper_func( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + is_driver_worker=rank == 0, + )) + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, " + "# CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, + max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + - args/kwargs: All workers share the same args/kwargs + - args/kwargs and driver_args/driver_kwargs: Driver worker has + different args + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + + if use_ray_compiled_dag: + # Right now, compiled DAG can only accept a single + # input. TODO(sang): Fix it. + assert self.forward_dag is not None + output_channels = self.forward_dag.execute(1) + else: + # Start the ray workers first. 
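The `all_args`/`all_kwargs` convention used by `_run_workers` above reserves index 0 for the driver worker and hands the remaining entries to the remote Ray workers via `islice`. A plain-Python sketch of that split, with made-up worker names and arguments (no Ray involved):

    from itertools import islice, repeat

    workers = ["ray-worker-1", "ray-worker-2"]            # stand-ins for Ray actor handles
    args = ()
    all_args = [("driver-arg",), ("worker1-arg",), ("worker2-arg",)]

    count = len(workers)
    all_worker_args = repeat(args, count) if all_args is None \
        else islice(all_args, 1, None)                    # skip the driver entry

    print(list(zip(workers, all_worker_args)))
    # [('ray-worker-1', ('worker1-arg',)), ('ray-worker-2', ('worker2-arg',))]

    driver_args = args if all_args is None else all_args[0]
    print(driver_args)                                    # ('driver-arg',)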
+ ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, + **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs + + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + # Get the results of the ray workers. + if self.workers: + if use_ray_compiled_dag: + try: + ray_worker_outputs = [ + pickle.loads(chan.begin_read()) + for chan in output_channels + ] + finally: + # Has to call end_read in order to reuse the DAG. + for chan in output_channels: + chan.end_read() + else: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def _compiled_ray_dag(self): + import pkg_resources + required_version = "2.9" + current_version = pkg_resources.get_distribution("ray").version + if current_version < required_version: + raise ValueError(f"Ray version {required_version} or greater is " + f"required, but found {current_version}") + + from ray.dag import InputNode, MultiOutputNode + assert self.parallel_config.worker_use_ray + + # Right now, compiled DAG requires at least 1 arg. We send + # a dummy value for now. It will be fixed soon. + with InputNode() as input_data: + forward_dag = MultiOutputNode([ + worker.execute_model_compiled_dag_remote. + bind( # type: ignore[attr-defined] + input_data) for worker in self.workers + ]) + return forward_dag.experimental_compile() + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + self._check_if_any_actor_is_dead() + + def _check_if_any_actor_is_dead(self): + if not self.workers: + return + + dead_actors = [] + for actor in self.workers: + actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access + if actor_state["State"] == "DEAD": + dead_actors.append(actor) + if dead_actors: + raise RuntimeError("At least one Worker is dead. " + f"Dead Workers: {dead_actors}. 
") + + +class RayXPUExecutorAsync(RayXPUExecutor, DistributedGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_method = make_async(self.driver_worker.execute_method) + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py new file mode 100644 index 0000000000000..d37200bd02de3 --- /dev/null +++ b/vllm/executor/xpu_executor.py @@ -0,0 +1,98 @@ +from typing import List, Optional + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import make_async +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class XPUExecutor(GPUExecutor): + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + ) -> None: + assert device_config.device_type == "xpu" + assert (not speculative_config + ), "Speculative decoding not yet supported for XPU backend" + + model_config = _verify_and_get_model_config(model_config) + + self.model_config = model_config + self.cache_config = cache_config + self.load_config = load_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + self.speculative_config = None + + # Instantiate the worker and load the model to GPU. 
+ self._init_executor() + + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + if self.speculative_config is None: + worker_module_name = "vllm.worker.xpu_worker" + worker_class_name = "XPUWorker" + else: + raise NotImplementedError( + "XPU does not support speculative decoding") + + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + ) + wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return wrapper.worker + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + +class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req) + return output + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype == torch.bfloat16: + logger.warning( + "bfloat16 is not fully supported on XPU, casting to float16.") + config.dtype = torch.float16 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + config.enforce_eager = True + return config diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 56aa629ae3455..0db72d8d95f24 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,6 +1,6 @@ import torch.nn as nn -from vllm.utils import is_cpu, is_hip, is_tpu +from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu class CustomOp(nn.Module): @@ -29,9 +29,7 @@ def forward_hip(self, *args, **kwargs): return self.forward_cuda(*args, **kwargs) def forward_xpu(self, *args, **kwargs): - # By default, we assume that XPU ops are compatible with CUDA ops. - # NOTE(woosuk): This is a placeholder for future extensions. - return self.forward_cuda(*args, **kwargs) + raise NotImplementedError def forward_cpu(self, *args, **kwargs): # By default, we assume that CPU ops are compatible with CUDA ops. @@ -58,5 +56,7 @@ def dispatch_forward(self): return self.forward_cpu elif is_tpu(): return self.forward_tpu + elif is_xpu(): + return self.forward_xpu else: return self.forward_cuda diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 4d076421f9d2a..eb0606948686d 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -37,6 +37,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.silu_and_mul(out, x) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + ops.silu_and_mul(out, x) + return out + class GeluAndMul(CustomOp): """An activation function for GeGLU. 
@@ -71,6 +80,18 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_tanh_and_mul(out, x) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if self.approximate == "none": + ops.gelu_and_mul(out, x) + elif self.approximate == "tanh": + ops.gelu_tanh_and_mul(out, x) + return out + def extra_repr(self) -> str: return f'approximate={repr(self.approximate)}' @@ -90,6 +111,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_new(out, x) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + out = torch.empty_like(x) + ops.gelu_new(out, x) + return out + class FastGELU(CustomOp): @@ -105,6 +133,13 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_fast(out, x) return out + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + out = torch.empty_like(x) + ops.gelu_fast(out, x) + return out + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 4533adf8f83aa..14f5e2378a421 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -67,6 +67,30 @@ def forward_cuda( ) return out + def forward_xpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + from vllm._ipex_ops import ipex_ops as ops + + if residual is not None: + ops.fused_add_rms_norm( + x, + residual, + self.weight.data, + self.variance_epsilon, + ) + return x, residual + out = torch.empty_like(x) + ops.rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + return out + def extra_repr(self) -> str: s = f"hidden_size={self.weight.data.size(0)}" s += f", eps={self.variance_epsilon}" diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 5a4940acbbef2..9c0a74cdab96e 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -221,6 +221,29 @@ def forward_cuda( self.cos_sin_cache, self.is_neox_style) return query, key + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. 
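The in-place note above is what lets `forward_xpu` avoid copying results back: `query_rot` and `key_rot` are views into `query` and `key`, so the IPEX kernels mutate the original tensors. A small pure-PyTorch illustration of that view/in-place pattern (shapes chosen arbitrarily):

    import torch

    num_tokens, num_heads, head_size = 2, 4, 8
    rotary_dim = head_size                            # full-rotary case for simplicity
    query = torch.randn(1, num_tokens, num_heads * head_size)

    q = query.view(*query.shape[:-1], -1, head_size)  # a view sharing storage with `query`
    q_rot = q[..., :rotary_dim]                       # slicing keeps it a view
    q_rot.zero_()                                     # in-place edit through the view

    print(bool(torch.all(query == 0)))                # True: the original tensor was updated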
+ if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + def forward_tpu( self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 60eb5b404e2ca..1a26c5c63fedc 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -307,7 +307,7 @@ def forward(self, input_): else: masked_input = input_ # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight) + output_parallel = F.embedding(masked_input.long(), self.weight) # Mask the output embedding. if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(1), 0) diff --git a/vllm/utils.py b/vllm/utils.py index 9b39ca77a9801..1adfa9218c047 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -160,6 +160,26 @@ def is_tpu() -> bool: return libtpu is not None +@lru_cache(maxsize=None) +def is_xpu() -> bool: + from importlib.metadata import version + is_xpu_flag = "xpu" in version("vllm") + # vllm is not build with xpu + if not is_xpu_flag: + return False + try: + import intel_extension_for_pytorch as ipex # noqa: F401 + _import_ipex = True + except ImportError as e: + logger.warning("Import Error for IPEX: %s", e.msg) + _import_ipex = False + # ipex dependency is not ready + if not _import_ipex: + logger.warning("not found ipex lib") + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" @@ -482,6 +502,9 @@ def is_pin_memory_available() -> bool: print_warning_once("Using 'pin_memory=False' as WSL is detected. " "This may slow down the performance.") return False + elif is_xpu(): + print_warning_once("Pin memory is not supported on XPU.") + return False elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False @@ -497,8 +520,12 @@ def __init__(self, device: Optional[torch.types.Device] = None): def current_memory_usage(self) -> float: # Return the memory usage in bytes. 
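The `is_xpu()` helper added above only reports True when three conditions hold: vLLM was built with `VLLM_TARGET_DEVICE=xpu`, IPEX imports cleanly, and a device is visible to `torch.xpu`. A condensed standalone check in the same spirit (not the exact vLLM helper, and it naturally needs an XPU environment to return True):

    from importlib.metadata import version

    import torch

    def xpu_ready() -> bool:
        if "xpu" not in version("vllm"):        # build tag added by setup.py for XPU builds
            return False
        try:
            import intel_extension_for_pytorch  # noqa: F401  (registers torch.xpu)
        except ImportError:
            return False
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    print(xpu_ready())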
- torch.cuda.reset_peak_memory_stats(self.device) - mem = torch.cuda.max_memory_allocated(self.device) + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats(self.device) + mem = torch.cuda.max_memory_allocated(self.device) + elif is_xpu(): + torch.xpu.reset_peak_memory_stats(self.device) + mem = torch.xpu.max_memory_allocated(self.device) return mem def __enter__(self): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 341b177d4af2a..fbd1343fea19c 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -4,7 +4,7 @@ import torch from vllm.attention import get_attn_backend -from vllm.config import CacheConfig, ModelConfig, ParallelConfig +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, is_pin_memory_available) @@ -25,10 +25,12 @@ def __init__( cache_config: CacheConfig, model_config: ModelConfig, parallel_config: ParallelConfig, + device_config: DeviceConfig, ) -> None: self.cache_config = cache_config self.model_config = model_config self.parallel_config = parallel_config + self.device_config = device_config self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) @@ -55,7 +57,8 @@ def __init__( ) # Initialize the cache. - self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") + self.gpu_cache = self._allocate_kv_cache( + self.num_gpu_blocks, self.device_config.device_type) self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") def _allocate_kv_cache( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7a378a862d0c0..f9b8a065a8b24 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -205,7 +205,8 @@ def initialize_cache(self, num_gpu_blocks: int, def _init_cache_engine(self): assert self.cache_config.num_gpu_blocks is not None self.cache_engine = CacheEngine(self.cache_config, self.model_config, - self.parallel_config) + self.parallel_config, + self.device_config) self.gpu_cache = self.cache_engine.gpu_cache def _warm_up_model(self) -> None: diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py new file mode 100644 index 0000000000000..f30de703e805d --- /dev/null +++ b/vllm/worker/xpu_model_runner.py @@ -0,0 +1,417 @@ +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad +from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 +_BATCH_SIZE_ALIGNMENT = 8 +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] + + +class XPUModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: 
Optional[VisionLanguageConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.load_config = load_config + self.cache_config = cache_config + self.vision_language_config = vision_language_config + self.is_driver_worker = is_driver_worker + + self.sliding_window = model_config.get_sliding_window() + self.device_config = device_config + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.block_size = cache_config.block_size + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture + if self.model_config is not None else 0) + + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Lazy initialization. + self.model: nn.Module # Set after load_model. + + def load_model(self) -> None: + with CudaMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_seqs sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for vision encoding, which needs + # to be accounted for when calculating the GPU blocks for + # the vLLM block manager. + # To exercise the worst-case scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + + seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=None, + multi_modal_data=dummy_multi_modal_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs.
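+ # The KV caches are passed as None placeholders here, so this forward + # pass measures only model weight and activation memory; cache blocks + # are sized and allocated later from the profiled peak usage.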
+ num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [None] * num_layers + self.execute_model(seqs, kv_caches) + torch.xpu.synchronize() + return + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, + Optional[torch.Tensor]]: + multi_modal_input = None + if self.is_driver_worker: + # NOTE: We assume that the sequences in the group are either all + # prompts or all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode(seq_group_metadata_list) + seq_lens = [] + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # subquery_lens is not needed if chunked prefill is not + # supported. Since the XPU worker doesn't support chunked prefill, + # just use seq_lens instead. + seq_lens, + self.device, + pin_memory=False) + # Broadcast the metadata. + metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "selected_token_indices": + sampling_metadata.selected_token_indices, + } + metadata_dict.update(attn_metadata.asdict_zerocopy()) + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict.pop("input_tokens") + input_positions = metadata_dict.pop("input_positions") + selected_token_indices = metadata_dict.pop( + "selected_token_indices") + attn_metadata = self.attn_backend.make_metadata(**metadata_dict) + sampling_metadata = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + + return (input_tokens, input_positions, attn_metadata, + sampling_metadata, multi_modal_input) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = 
torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seqlen_q=None, + max_seqlen=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[torch.Tensor], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, attn_metadata, sampling_metadata, + multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return None + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + Optional[torch.Tensor]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_input_list: List[torch.Tensor] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, seq_len))) + + if seq_group_metadata.multi_modal_data: + multi_modal_input_list.append( + seq_group_metadata.multi_modal_data.data) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). 
+ # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if multi_modal_input_list: + assert self.vision_language_config, ( + "Multi-modal inputs are only supported by " + "vision language models.") + multi_modal_input = torch.cat(multi_modal_input_list, + dim=0).to(self.device) + else: + multi_modal_input = None + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + max_seqlen = max(seq_lens) + tmp = [0] + tmp.extend(seq_lens) + seqlen = torch.tensor(tmp) + seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seqlen_q=seqlen_q, + max_seqlen=max_seqlen, + seq_lens_tensor=None, + max_decode_seq_len=None, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([], device=self.device, dtype=torch.int), + ) + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_input) diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py new file mode 100644 index 0000000000000..773ee9f8159e1 --- /dev/null +++ b/vllm/worker/xpu_worker.py @@ -0,0 +1,193 @@ +"""An XPU worker class.""" +import gc +import os +from typing import List, Optional, Tuple + +import intel_extension_for_pytorch # noqa: F401 +import oneccl_bindings_for_pytorch # noqa: F401 +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import is_xpu +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.xpu_model_runner import XPUModelRunner + +logger = init_logger(__name__) + + +class XPUWorker(LoraNotSupportedWorkerBase, Worker): + """A worker class that executes (a partition of) the model on an XPU. + + Each worker is associated with a single XPU device. The worker is + responsible for maintaining the KV cache and executing the model on the + XPU. In case of distributed inference, each worker is assigned a partition + of the model. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + is_driver_worker: bool = False, + ) -> None: + assert device_config.device_type == "xpu" + assert is_xpu() + + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.vision_language_config = vision_language_config + if self.vision_language_config: + assert not self.lora_config, ( + "To be tested: vision language model with LoRA settings.") + + self.model_runner = XPUModelRunner( # type: ignore + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + vision_language_config=vision_language_config, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: CacheEngine + self.gpu_cache: List[torch.Tensor] + + def init_device(self) -> None: + if self.device_config.device.type == "xpu" and is_xpu(): + self.device = torch.device(f"xpu:{self.local_rank}") + torch.xpu.set_device(self.device) + torch.xpu.empty_cache() + self.init_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + self.init_worker_distributed_environment() + # Initialize the model. + set_random_seed(self.model_config.seed) + + # keep this method for `empty_cache` and `synchronize` api + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.xpu.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ torch.xpu.synchronize() + used_memory = torch.xpu.memory_allocated() + total_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + free_gpu_memory = total_gpu_memory - used_memory + + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + gc.collect() + torch.xpu.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _warm_up_model(self) -> None: + # IPEX doesn't support graph capture yet. + pass + + def init_worker_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch " + "world size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + # Use sockets as the default Level Zero IPC exchange backend. By + # default, oneCCL uses `drmfd` as the mechanism, which needs extra + # dependencies (libdrm and DRM headers) on your system. + ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", + "sockets") + os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=self.local_rank, + backend="ccl") + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size)
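Taken together, these changes wire the "xpu" device type through the cache engine, worker, and model runner. As a rough usage sketch only (assuming a working XPU build of vLLM, an Intel GPU visible to the driver, and that the engine accepts device="xpu" as the updated benchmark scripts suggest; the model name is an illustrative choice), offline inference would look like:

from vllm import LLM, SamplingParams

# Select the XPU backend via the device argument.
llm = LLM(model="facebook/opt-125m", device="xpu")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Generate a completion and print the text of the first candidate.
outputs = llm.generate(["The Intel Data Center GPU is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)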