From af2e6557f78b4293c6d39b927308720efc26d9d5 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Fri, 28 Jun 2024 17:50:16 +0400 Subject: [PATCH] [Hardware][Intel] OpenVINO vLLM backend (#5379) --- .buildkite/run-openvino-test.sh | 14 + Dockerfile.openvino | 26 ++ benchmarks/benchmark_latency.py | 7 +- benchmarks/benchmark_throughput.py | 7 +- .../getting_started/openvino-installation.rst | 95 +++++ docs/source/index.rst | 1 + requirements-openvino.txt | 9 + setup.py | 11 +- tests/kernels/test_attention_selector.py | 9 +- vllm/attention/backends/openvino.py | 101 +++++ vllm/attention/selector.py | 12 +- vllm/config.py | 8 +- vllm/engine/arg_utils.py | 14 +- vllm/engine/async_llm_engine.py | 6 + vllm/engine/llm_engine.py | 3 + vllm/envs.py | 22 +- vllm/executor/openvino_executor.py | 163 ++++++++ vllm/model_executor/layers/sampler.py | 4 +- vllm/model_executor/model_loader/openvino.py | 210 +++++++++++ vllm/utils.py | 11 +- vllm/worker/openvino_model_runner.py | 330 ++++++++++++++++ vllm/worker/openvino_worker.py | 353 ++++++++++++++++++ 22 files changed, 1393 insertions(+), 23 deletions(-) create mode 100755 .buildkite/run-openvino-test.sh create mode 100644 Dockerfile.openvino create mode 100644 docs/source/getting_started/openvino-installation.rst create mode 100644 requirements-openvino.txt create mode 100644 vllm/attention/backends/openvino.py create mode 100644 vllm/executor/openvino_executor.py create mode 100644 vllm/model_executor/model_loader/openvino.py create mode 100644 vllm/worker/openvino_model_runner.py create mode 100644 vllm/worker/openvino_worker.py diff --git a/.buildkite/run-openvino-test.sh b/.buildkite/run-openvino-test.sh new file mode 100755 index 0000000000000..70e56596c4a86 --- /dev/null +++ b/.buildkite/run-openvino-test.sh @@ -0,0 +1,14 @@ +# This script build the OpenVINO docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t openvino-test -f Dockerfile.openvino . + +# Setup cleanup +remove_docker_container() { docker rm -f openvino-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image and launch offline inference +docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py diff --git a/Dockerfile.openvino b/Dockerfile.openvino new file mode 100644 index 0000000000000..9861997b451a9 --- /dev/null +++ b/Dockerfile.openvino @@ -0,0 +1,26 @@ +# The vLLM Dockerfile is used to construct vLLM image that can be directly used +# to run the OpenAI compatible server. + +FROM ubuntu:22.04 AS dev + +RUN apt-get update -y && \ + apt-get install -y python3-pip git +WORKDIR /workspace + +# copy requirements +COPY requirements-build.txt /workspace/vllm/ +COPY requirements-common.txt /workspace/vllm/ +COPY requirements-openvino.txt /workspace/vllm/ + +COPY vllm/ /workspace/vllm/vllm +COPY setup.py /workspace/vllm/ + +# install build requirements +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt +# build vLLM with OpenVINO backend +RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ + +COPY examples/ /workspace/vllm/examples +COPY benchmarks/ /workspace/vllm/benchmarks + +CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index f3d00e456f159..a46ee15817f4c 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -207,9 +207,10 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument( "--device", type=str, - default="cuda", - choices=["cuda", "cpu", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA and CPU.') + default="auto", + choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') parser.add_argument('--block-size', type=int, default=16, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 2c6beb4e89672..a52e67bbbe7e3 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -349,9 +349,10 @@ def main(args: argparse.Namespace): parser.add_argument( "--device", type=str, - default="cuda", - choices=["cuda", "cpu", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA and CPU.') + default="auto", + choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], + help='device type for vLLM execution, supporting CUDA, OpenVINO and ' + 'CPU.') parser.add_argument( "--enable-prefix-caching", action='store_true', diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst new file mode 100644 index 0000000000000..0d8e0b680ff0d --- /dev/null +++ b/docs/source/getting_started/openvino-installation.rst @@ -0,0 +1,95 @@ +.. _installation_openvino: + +Installation with OpenVINO +========================== + +vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: + +- Prefix caching (``--enable-prefix-caching``) +- Chunked prefill (``--enable-chunked-prefill``) + +**Table of contents**: + +- :ref:`Requirements ` +- :ref:`Quick start using Dockerfile ` +- :ref:`Build from source ` +- :ref:`Performance tips ` +- :ref:`Limitations ` + +.. _openvino_backend_requirements: + +Requirements +------------ + +* OS: Linux +* Instruction set architecture (ISA) requirement: at least AVX2. + +.. _openvino_backend_quick_start_dockerfile: + +Quick start using Dockerfile +---------------------------- + +.. code-block:: console + + $ docker build -f Dockerfile.openvino -t vllm-openvino-env . + $ docker run -it --rm vllm-openvino-env + +.. _install_openvino_backend_from_source: + +Install from source +------------------- + +- First, install Python. For example, on Ubuntu 22.04, you can run: + + .. code-block:: console + + $ sudo apt-get update -y + $ sudo apt-get install python3 + +- Second, install prerequisites vLLM OpenVINO backend installation: + + .. code-block:: console + + $ pip install --upgrade pip + $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu + +- Finally, install vLLM with OpenVINO backend: + + .. code-block:: console + + $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v . + +.. _openvino_backend_performance_tips: + +Performance tips +---------------- + +vLLM OpenVINO backend uses the following environment variables to control behavior: + +- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. + +- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform. + +- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. + +To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``) + +OpenVINO best known configuration is: + +.. code-block:: console + + $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \ + python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256 + +.. _openvino_backend_limitations: + +Limitations +----------- + +- LoRA serving is not supported. + +- Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in vLLM OpenVINO integration. + +- Tensor and pipeline parallelism are not currently enabled in vLLM integration. + +- Speculative sampling is not tested within vLLM integration. diff --git a/docs/source/index.rst b/docs/source/index.rst index 3a9f5a3d81e84..8fd25ce828839 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -63,6 +63,7 @@ Documentation getting_started/installation getting_started/amd-installation + getting_started/openvino-installation getting_started/cpu-installation getting_started/neuron-installation getting_started/tpu-installation diff --git a/requirements-openvino.txt b/requirements-openvino.txt new file mode 100644 index 0000000000000..e555d52572541 --- /dev/null +++ b/requirements-openvino.txt @@ -0,0 +1,9 @@ +# Common dependencies +-r requirements-common.txt + +# OpenVINO dependencies +torch >= 2.1.2 +openvino ~= 2024.3.0.dev +optimum-intel[openvino] >= 1.17.2 + +triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. diff --git a/setup.py b/setup.py index b2ae6def8cdc6..067ad13fed71b 100644 --- a/setup.py +++ b/setup.py @@ -233,6 +233,10 @@ def _is_cpu() -> bool: return VLLM_TARGET_DEVICE == "cpu" +def _is_openvino() -> bool: + return VLLM_TARGET_DEVICE == "openvino" + + def _is_xpu() -> bool: return VLLM_TARGET_DEVICE == "xpu" @@ -337,6 +341,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_openvino(): + version += "+openvino" elif _is_tpu(): version += "+tpu" elif _is_cpu(): @@ -388,6 +394,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): requirements = _read_requirements("requirements-neuron.txt") + elif _is_openvino(): + requirements = _read_requirements("requirements-openvino.txt") elif _is_tpu(): requirements = _read_requirements("requirements-tpu.txt") elif _is_cpu(): @@ -396,7 +404,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-xpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") + "Unsupported platform, please use CUDA, ROCm, Neuron, " + "OpenVINO, or CPU.") return requirements diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 79e03c7478de0..8e6c50666e70c 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -9,8 +9,8 @@ @pytest.mark.parametrize( - "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"]) -@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) + "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"]) +@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"]) def test_env(name: str, device: str, monkeypatch): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. @@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch): backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) assert backend.name == "ROCM_FLASH" + elif device == "openvino": + with patch("vllm.attention.selector.is_openvino", return_value=True): + backend = which_attn_to_use(8, 16, 8, None, torch.float16, + torch.float16, 16) + assert backend.name == "OPENVINO" else: backend = which_attn_to_use(8, 16, 8, None, torch.float16, torch.float16, 16) diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py new file mode 100644 index 0000000000000..0f21b50ad4dc7 --- /dev/null +++ b/vllm/attention/backends/openvino.py @@ -0,0 +1,101 @@ +from dataclasses import dataclass +from typing import List, Tuple + +import openvino as ov +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) + + +class OpenVINOAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "openvino" + + @staticmethod + def get_impl_cls(): + # OpenVINO implements PagedAttention as part of the Optimum + # exported model + raise NotImplementedError + + @staticmethod + def make_metadata(*args, **kwargs) -> "AttentionMetadata": + raise NotImplementedError + + @staticmethod + def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": + return OpenVINOAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: ov.Tensor, + dst_kv_cache: ov.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + # OpenVINO currently supports only CPU, which does not require + # swap of KV cache blocks + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + src_to_dists: List[Tuple[int, int]], + ) -> None: + for src, dst in src_to_dists: + for key_cache, value_cache in kv_caches: + key_cache.data[dst, :] = key_cache.data[src, :] + value_cache.data[dst, :] = value_cache.data[src, :] + + +@dataclass +class OpenVINOAttentionMetadata: + """Metadata for OpenVINOAttentionBackend. + + Basic terms used below: + - batch_size_in_sequences - total number of sequences to execute​ + - prompt_lens – per sequence size number of scheduled tokens​ + - batch_size_in_tokens = sum(prompt_lens)​ + - max_context_len = max(context_lens)​ + - max_num_blocks = div_up(max_context_len / BLOCK_SIZE)​ + - num_blocks – total number of blocks in block_indices​ + """ + + # Describes past KV cache size for each sequence within a batch + # Shape: [batch_size_in_sequences] + # Type: i32​ + past_lens: torch.Tensor + + # Describes start indices of input / speculative tokens from + # current sequences within a batch sequence​ + # Shape: [batch_size_in_sequences + 1]​ + # Type: i32 + subsequence_begins: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # indices along 0th dimension in key_cache and value_cache inputs​ + # Shape: [num_blocks] + # Type: i32​ + block_indices: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # for i-th element, it is an index in block_indices with the + # first block belonging to i-th sequence​ + # Shape: [batch_size_in_sequences + 1] + # Type: i32​ + block_indices_begins: torch.Tensor + + # Describes max context length + # Shape: scalar + # Type: i32 + max_context_len: torch.Tensor diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 1d56d87ccd119..96f88bbf4cf59 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -7,7 +7,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger -from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu +from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu logger = init_logger(__name__) @@ -17,6 +17,7 @@ class _Backend(enum.Enum): XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() FLASHINFER = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() @@ -64,6 +65,10 @@ def get_attn_backend( logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.OPENVINO: + logger.info("Using OpenVINO Attention backend.") + from vllm.attention.backends.openvino import OpenVINOAttentionBackend + return OpenVINOAttentionBackend elif backend == _Backend.IPEX: assert is_xpu(), RuntimeError( "IPEX attention backend is only used for the XPU device.") @@ -113,6 +118,11 @@ def which_attn_to_use( logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA + if is_openvino(): + if selected_backend != _Backend.OPENVINO: + logger.info("Cannot use %s backend on OpenVINO.", selected_backend) + return _Backend.OPENVINO + if is_xpu(): if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) diff --git a/vllm/config.py b/vllm/config.py index 6adeaf4209570..31d30cfa73d1f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -14,8 +14,8 @@ from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_tpu, is_xpu, print_warning_once, - update_environment_variables) + is_hip, is_neuron, is_openvino, is_tpu, is_xpu, + print_warning_once, update_environment_variables) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -781,6 +781,8 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_openvino(): + self.device_type = "openvino" elif is_tpu(): self.device_type = "tpu" elif is_cpu(): @@ -796,7 +798,7 @@ def __init__(self, device: str = "auto") -> None: self.device_type = device # Some device types require processing inputs on CPU - if self.device_type in ["neuron"]: + if self.device_type in ["neuron", "openvino"]: self.device = torch.device("cpu") elif self.device_type in ["tpu"]: self.device = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c392155e8981b..f9d089091ffc2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -504,12 +504,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) - parser.add_argument( - "--device", - type=str, - default=EngineArgs.device, - choices=["auto", "cuda", "neuron", "cpu", "tpu", "xpu"], - help='Device type for vLLM execution.') + parser.add_argument("--device", + type=str, + default=EngineArgs.device, + choices=[ + "auto", "cuda", "neuron", "cpu", "openvino", + "tpu", "xpu" + ], + help='Device type for vLLM execution.') # Related to Vision-language models such as llava parser = EngineArgs.add_cli_args_for_vlm(parser) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 848e05f033a8e..7db3bb28c6ee5 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -393,6 +393,12 @@ def from_engine_args( "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "openvino": + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with " + "the OpenVINO backend.") + from vllm.executor.openvino_executor import OpenVINOExecutorAsync + executor_class = OpenVINOExecutorAsync elif engine_config.device_config.device_type == "xpu": if distributed_executor_backend is None: from vllm.executor.xpu_executor import XPUExecutorAsync diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9b720d6138868..fde18f60e4ddd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -363,6 +363,9 @@ def from_engine_args( elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor + elif engine_config.device_config.device_type == "openvino": + from vllm.executor.openvino_executor import OpenVINOExecutor + executor_class = OpenVINOExecutor elif engine_config.device_config.device_type == "xpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) diff --git a/vllm/envs.py b/vllm/envs.py index 49277e2d3519f..e8257535f1bf5 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -28,6 +28,9 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_OPENVINO_KVCACHE_SPACE: int = 0 + VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None + VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/" VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "fork" @@ -49,7 +52,8 @@ # ================== Installation Time Env Vars ================== - # Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] + # Target device of vLLM, supporting [cuda (by default), + # rocm, neuron, cpu, openvino] "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), @@ -208,6 +212,22 @@ "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + # OpenVINO key-value cache space + # default is 4GB + "VLLM_OPENVINO_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")), + + # OpenVINO KV cache precision + # default is bf16 if natively supported by platform, otherwise f16 + # To enable KV cache compression, please, explicitly specify u8 + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION": + lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None), + + # Enables weights compression during model export via HF Optimum + # default is False + "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": + lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py new file mode 100644 index 0000000000000..8af375371f2f0 --- /dev/null +++ b/vllm/executor/openvino_executor.py @@ -0,0 +1,163 @@ +from typing import List, Set, Tuple + +import openvino as ov +import openvino.properties.hint as hints +import torch + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +logger = init_logger(__name__) + + +class OpenVINOExecutor(ExecutorBase): + + def _init_executor(self) -> None: + assert self.device_config.device_type == "openvino" + assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + + # Instantiate the worker and load the model to CPU. + self._init_worker() + + def _init_worker(self): + from vllm.worker.openvino_worker import OpenVINOWorker + + assert ( + self.parallel_config.world_size == 1 + ), "OpenVINOExecutor only supports single CPU socket currently." + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = OpenVINOWorker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker.""" + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + # NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info("# CPU blocks: %d", num_gpu_blocks) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def check_health(self) -> None: + # OpenVINOExecutor will always be healthy as long as + # it's running. + return + + +class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + # OpenVINOExecutor will always be healthy as long as + # it's running. + return + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype != torch.float32: + logger.warning( + f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501 + ) + config.dtype = torch.float32 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on OpenVINO backend, fallback to the " + "eager mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: + if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": + logger.info("KV cache type is overried to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + config.cache_dtype = ov.Type.u8 + else: + core = ov.Core() + inference_precision = core.get_property("CPU", + hints.inference_precision) + if inference_precision == ov.Type.bf16: + config.cache_dtype = ov.Type.bf16 + else: + config.cache_dtype = ov.Type.f16 + + if config.block_size != 32: + logger.info( + f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 + ) + config.block_size = 32 + + kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE + if kv_cache_space >= 0: + _GB = 1 << 30 + if kv_cache_space == 0: + config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore + logger.warning( + "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " + "for OpenVINO backend is not set, using 4 by default.") + else: + config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore + else: + raise RuntimeError( + "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index e07360a2fd682..6d00ea64f7cb8 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -679,7 +679,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. - Each element in the returned tensor represents the rank + Each element in the returned tensor represents the rank of the chosen token in the input logprob tensor. """ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), @@ -965,7 +965,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, distribution. - Greedy sampling performs `argmax` to obtain the token with the highest likelihood. - + Ignoring greedy sampling for a moment, we find that the computed probability distribution has the following property: we can sample from it independently and find that the token sampled by the Sampler has a frequency corresponding diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py new file mode 100644 index 0000000000000..5c522a61732a4 --- /dev/null +++ b/vllm/model_executor/model_loader/openvino.py @@ -0,0 +1,210 @@ +# ruff: noqa: SIM117 +from pathlib import Path +from typing import List, Optional, Tuple + +import openvino as ov +import torch +from huggingface_hub import HfApi +from openvino._offline_transformations import paged_attention_transformation +from optimum.intel import OVModelForCausalLM +from torch import nn + +import vllm.envs as envs +from vllm.attention.backends.openvino import OpenVINOAttentionMetadata +from vllm.config import DeviceConfig, ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import (LogitsProcessor, + _prune_hidden_states) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +def _flattenize_inputs(inputs): + """ + Helper function for making nested inputs flattens + """ + flatten_inputs = [] + for input_data in inputs: + if input_data is None: + continue + if isinstance(input_data, (list, tuple)): + flatten_inputs.extend(_flattenize_inputs(input_data)) + elif isinstance(input_data, dict): + flatten_inputs.extend(_flattenize_inputs(list( + input_data.values()))) + else: + flatten_inputs.append(input_data) + return flatten_inputs + + +def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type, + is_cpu: bool): + # Apply hardware dependent modifications to KV tensors + for parameter in model.get_parameters(): + input = parameter.get_output_tensor(0) + input_names = input.get_names() + if len(input_names) != 1: + continue + input_name = next(iter(input_names)) + shape = parameter.get_partial_shape() + # use real block size if available, just a placeholder + # to provide the expected rank + x_size = 1 + num_blocks = ov.Dimension() + block_size = ov.Dimension() + head_size = ov.Dimension() + # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD), + # pass more parameters to this function to set more static dimensions + if input_name.startswith("key_cache."): + cpu_shape = [num_blocks, shape[1], block_size, head_size] + gpu_shape = [ + num_blocks, + shape[1], + shape[2].get_length() // + x_size if shape[2].is_static else ov.Dimension(), + block_size, + x_size, + ] + elif input_name.startswith("value_cache."): + cpu_shape = [num_blocks, shape[1], block_size, head_size] + gpu_shape = [num_blocks, shape[1], shape[2], block_size] + else: + continue + parameter.set_partial_shape( + ov.PartialShape(cpu_shape if is_cpu else gpu_shape)) + parameter.set_element_type(kv_cache_dtype) + model.validate_nodes_and_infer_types() + + +def _require_model_export(model_id, revision=None, subfolder=None): + model_dir = Path(model_id) + if subfolder is not None: + model_dir = model_dir / subfolder + if model_dir.is_dir(): + return (not (model_dir / "openvino_model.xml").exists() + or not (model_dir / "openvino_model.bin").exists()) + + hf_api = HfApi() + try: + model_info = hf_api.model_info(model_id, revision=revision or "main") + normalized_subfolder = (None if subfolder is None else + Path(subfolder).as_posix()) + model_files = [ + file.rfilename for file in model_info.siblings + if normalized_subfolder is None + or file.rfilename.startswith(normalized_subfolder) + ] + ov_model_path = ("openvino_model.xml" if normalized_subfolder is None + else f"{normalized_subfolder}/openvino_model.xml") + return (ov_model_path not in model_files + or ov_model_path.replace(".xml", ".bin") not in model_files) + except Exception: + return True + + +class OpenVINOCasualLM(nn.Module): + + def __init__( + self, + model_config: ModelConfig, + device_config: DeviceConfig, + kv_cache_dtype: ov.Type, + ) -> None: + super().__init__() + self.logits_processor = LogitsProcessor( + model_config.hf_config.vocab_size, logits_as_input=True) + self.sampler = Sampler() + + export = _require_model_export(model_config.model) + if export: + logger.warning( + f"Provided model id {model_config.model} does not " # noqa: G004 + "contain OpenVINO IR, the model will be converted to IR with " + "default options. If you need to use specific options for " + "model conversion, use optimum-cli export openvino with " + "desired options.") + else: + logger.warning( + "OpenVINO IR is available for provided model id " # noqa: G004 + f"{model_config.model}. This IR will be used for inference " + "as-is, all possible options that may affect model conversion " + "are ignored.") + + load_in_8bit = envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS + pt_model = OVModelForCausalLM.from_pretrained( + model_config.model, + export=export, + compile=False, + load_in_8bit=load_in_8bit, + trust_remote_code=model_config.trust_remote_code, + ) + + paged_attention_transformation(pt_model.model) + _modify_cache_parameters(pt_model.model, kv_cache_dtype, + device_config.device.type == "cpu") + + core = ov.Core() + ov_compiled = core.compile_model(pt_model.model, "CPU") + self.ov_request = ov_compiled.create_infer_request() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + attn_metadata: OpenVINOAttentionMetadata, + ) -> torch.Tensor: + flatten_kv_cache = _flattenize_inputs(kv_caches) + + inputs = [ + input_ids, + positions, + *flatten_kv_cache, + attn_metadata.past_lens, + attn_metadata.subsequence_begins, + attn_metadata.block_indices, + attn_metadata.block_indices_begins, + attn_metadata.max_context_len, + ] + + self.ov_request.start_async(inputs, share_inputs=True) + self.ov_request.wait() + + logits = torch.from_numpy(self.ov_request.get_tensor("logits").data) + + # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension + return logits.view(-1, logits.shape[-1]) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + +def get_model( + model_config: ModelConfig, + device_config: DeviceConfig, + kv_cache_dtype: ov.Type, + **kwargs, +) -> torch.nn.Module: + lora_config = kwargs.get("lora_config", None) + if lora_config: + raise ValueError( + "OpenVINO modeling does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + + return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype) diff --git a/vllm/utils.py b/vllm/utils.py index 92abdb3fb9b14..6e8d4624cd713 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -176,6 +176,15 @@ def is_cpu() -> bool: return False +@lru_cache(maxsize=None) +def is_openvino() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "openvino" in version("vllm") + except PackageNotFoundError: + return False + + @lru_cache(maxsize=None) def is_neuron() -> bool: try: @@ -546,7 +555,7 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False - elif is_cpu(): + elif is_cpu() or is_openvino(): return False return True diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py new file mode 100644 index 0000000000000..336eaf814fb3f --- /dev/null +++ b/vllm/worker/openvino_model_runner.py @@ -0,0 +1,330 @@ +from typing import List, NamedTuple, Optional, Tuple + +import openvino as ov +import torch +from torch import nn + +from vllm.attention import get_attn_backend +from vllm.attention.backends.openvino import OpenVINOAttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader.openvino import get_model +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + +logger = init_logger(__name__) + + +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[OpenVINOAttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + multi_modal_input: Optional[torch.Tensor] + + @classmethod + def empty(cls, device): + return ModelInput(input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), + attn_metadata=None, + seq_lens=[], + query_lens=[], + multi_modal_input=None) + + +class OpenVINOModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.vision_language_config = vision_language_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + def load_model(self) -> None: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + kv_cache_dtype=self.kv_cache_dtype, + ) + + def _prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + """ + input_tokens: List[int] = [] + input_positions: List[int] = [] + + seq_lens: List[int] = [] + past_lens: List[int] = [] + query_lens: List[int] = [] + subsequence_begins: List[int] = [] + block_indices: List[int] = [] + block_indices_begins: List[int] = [] + + # initialize beginning of prefix sums + subsequence_begins.append(0) + block_indices_begins.append(0) + + if len(seq_group_metadata_list) == 0: + return ModelInput.empty(self.device) + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.seq_data[seq_id] + if is_prompt: + computed_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + computed_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + computed_len + seq_group_metadata.token_chunk_size, + ) + if is_prompt: + tokens = seq_data.get_token_ids()[computed_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + block_table = seq_group_metadata.block_tables[seq_id] + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + if prefix_cache_hit: + assert computed_block_nums is not None + computed_len = len(computed_block_nums) * self.block_size + tokens = tokens[computed_len:] + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert not self.scheduler_config.chunked_prefill_enabled # noqa: E501 + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # prompt phase w/o prefix_caching, chunked_prefill + pass + + block_indices.extend(block_table) + block_indices_begins.append(block_indices_begins[-1] + + len(block_table)) + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if self.sliding_window is not None and not is_prompt: + seq_len = min(seq_len, self.sliding_window) + computed_len = seq_len - 1 + + seq_lens.append(seq_len) + + query_len = seq_len - computed_len + query_lens.append(query_len) + + input_tokens.extend(tokens) + input_positions.extend(list(range(computed_len, seq_len))) + + past_lens.append(computed_len) + subsequence_begins.append(subsequence_begins[-1] + query_len) + + if is_prompt: + assert len(seq_ids) == 1 + else: + assert ( + query_len == 1 + ), "seq_len: {}, computed_len: {}, query_len: {}".format( + seq_len, computed_len, query_len) + + max_query_len = max(query_lens) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + + past_lens_tensor = torch.tensor(past_lens, + dtype=torch.int32, + device=self.device) # type: ignore + subsequence_begins_tensor = torch.tensor( + subsequence_begins, dtype=torch.int32, + device=self.device) # type: ignore + block_indices_tensor = torch.tensor(block_indices, + dtype=torch.int32, + device=self.device) # type: ignore + block_indices_begins_tensor = torch.tensor( + block_indices_begins, dtype=torch.int32, + device=self.device) # type: ignore + + max_context_len = max(seq_lens) + max_context_len_tensor = torch.tensor( + max_context_len, dtype=torch.int32, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_openvino_metadata( + past_lens=past_lens_tensor, + subsequence_begins=subsequence_begins_tensor, + block_indices=block_indices_tensor, + block_indices_begins=block_indices_begins_tensor, + max_context_len=max_context_len_tensor, + ) + return ModelInput( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + None, + ) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata, + SamplingMetadata, Optional[torch.Tensor], ]: + multi_modal_input = None + + # Prepare input tensors. + ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + multi_modal_input, + ) = self._prepare_model_input(seq_group_metadata_list) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens, + self.device, + pin_memory=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + multi_modal_input, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]], + ) -> Optional[SamplerOutput]: + ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + multi_modal_input, + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "attn_metadata": attn_metadata, + } + if self.vision_language_config: + execute_model_kwargs.update({"image_input": multi_modal_input}) + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py new file mode 100644 index 0000000000000..7a462ce5d0b66 --- /dev/null +++ b/vllm/worker/openvino_worker.py @@ -0,0 +1,353 @@ +"""An OpenVINO worker class.""" +from typing import Any, Dict, List, Optional, Tuple + +import openvino as ov +import torch +import torch.distributed + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.openvino_model_runner import OpenVINOModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + +logger = init_logger(__name__) + + +class OpenVINOCacheEngine: + """Manages the KV cache for OpenVINO backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ) -> None: + assert device_config.device_type == "openvino" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + if device_config.device.type == "cpu" and \ + cache_config.cache_dtype == ov.Type.u8: + # Scale, zero point and quantized data will be stored together. + # The layout for per token per head: + # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + # so, we have to extend head_size by 8, which is sizeof(float) + # for scale and sizeof(float) for zeropoint + self.head_size += 8 + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for OpenVINO backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + # Get attention backend. + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.head_size, + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + ) + + # Initialize the cache. + self.kv_cache: List[Tuple[ov.Tensor, + ov.Tensor]] = self._allocate_kv_cache( + self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[Tuple[ov.Tensor, ov.Tensor]]: + """Allocates KV cache.""" + k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:] + kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = [] + for _ in range(self.num_layers): + key_blocks = ov.Tensor(self.cache_config.cache_dtype, + k_block_shape) + value_blocks = ov.Tensor(self.cache_config.cache_dtype, + v_block_shape) + kv_cache.append((key_blocks, value_blocks)) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError( + "Swap is not supported in OpenVINOCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError( + "Swap is not supported in OpenVINOCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: ov.Type, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + if cache_dtype == ov.Type.u8: + # Scale, zero point and quantized data will be stored together. + # The layout for per token per head: + # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + # so, we have to extend head_size by 8, which is sizeof(float) + # for scale and sizeof(float) for zeropoint + head_size += 8 + + key_cache_block = block_size * num_kv_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = cache_dtype.size + return dtype_size * total + + +class OpenVINOWorker(LoraNotSupportedWorkerBase): + """A worker class that executes the model on OpenVINO backend. + + Each worker is associated with a single OpenVINO device. The worker is + responsible for maintaining the KV cache and executing the model on the + OpenVINO backend. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + vision_language_config: Optional[VisionLanguageConfig] = None, + kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.vision_language_config = vision_language_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + self.model_runner = OpenVINOModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + vision_language_config=self.vision_language_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: OpenVINOCacheEngine + self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] + + def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For OpenVINO backend, the block number will be calculated based on the + # openvino_kvcache_space_bytes. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid.""" + if num_cpu_blocks <= 0: + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` " + "when initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = OpenVINOCacheEngine( + self.cache_config, + self.model_config, + self.parallel_config, + self.device_config, + ) + self.kv_cache = self.cache_engine.kv_cache + self.model_runner.block_size = self.cache_engine.block_size + + assert self.kv_cache is not None + + # Populate the cache to warmup the memory + for key_cache, value_cache in self.kv_cache: + key_cache.data[:] = 0 + value_cache.data[:] = 0 + + def cache_copy( + self, + blocks_to_copy: List[Tuple[int, int]], + ) -> None: + self.cache_engine.copy(blocks_to_copy) # type: ignore + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups: int = len(seq_group_metadata_list) + assert execute_model_req is not None + blocks_to_copy = execute_model_req.blocks_to_copy + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + data: Dict[str, Any] = { + "num_seq_groups": num_seq_groups, + "blocks_to_copy": execute_model_req.blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_copy(blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.kv_cache) + + # OpenVINO worker only supports single-step execution. + return [output] + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + ) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block.""" + return OpenVINOCacheEngine.get_cache_block_size( + self.cache_config.block_size, + self.cache_config.cache_dtype, + self.model_config, + self.parallel_config, + )