[Hardware][Intel] OpenVINO vLLM backend (vllm-project#5379)
ilya-lavrenov authored Jun 28, 2024
1 parent 5932634 commit 57f09a4
Showing 22 changed files with 1,393 additions and 23 deletions.
14 changes: 14 additions & 0 deletions .buildkite/run-openvino-test.sh
@@ -0,0 +1,14 @@
# This script builds the OpenVINO docker image and runs offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t openvino-test -f Dockerfile.openvino .

# Setup cleanup
remove_docker_container() { docker rm -f openvino-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py
26 changes: 26 additions & 0 deletions Dockerfile.openvino
@@ -0,0 +1,26 @@
# This Dockerfile is used to construct a vLLM image that can be used directly
# to run the OpenAI-compatible server.

FROM ubuntu:22.04 AS dev

RUN apt-get update -y && \
apt-get install -y python3-pip git
WORKDIR /workspace

# copy requirements
COPY requirements-build.txt /workspace/vllm/
COPY requirements-common.txt /workspace/vllm/
COPY requirements-openvino.txt /workspace/vllm/

COPY vllm/ /workspace/vllm/vllm
COPY setup.py /workspace/vllm/

# install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with OpenVINO backend
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/

COPY examples/ /workspace/vllm/examples
COPY benchmarks/ /workspace/vllm/benchmarks

CMD ["/bin/bash"]
7 changes: 4 additions & 3 deletions benchmarks/benchmark_latency.py
@@ -207,9 +207,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
default="auto",
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument('--block-size',
type=int,
default=16,
7 changes: 4 additions & 3 deletions benchmarks/benchmark_throughput.py
@@ -349,9 +349,10 @@ def main(args: argparse.Namespace):
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
default="auto",
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
help='device type for vLLM execution, supporting CUDA, OpenVINO and '
'CPU.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
95 changes: 95 additions & 0 deletions docs/source/getting_started/openvino-installation.rst
@@ -0,0 +1,95 @@
.. _installation_openvino:

Installation with OpenVINO
==========================

vLLM powered by OpenVINO supports all LLM models from the :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support. The OpenVINO vLLM backend supports the following advanced vLLM features (a short usage sketch follows this list):

- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)
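
For example, both features can be enabled from the Python API. The following is a minimal sketch (the model and prompt are illustrative), assuming vLLM was built with ``VLLM_TARGET_DEVICE=openvino`` as described below:

.. code-block:: python

    from vllm import LLM, SamplingParams

    # On an OpenVINO build of vLLM, the backend is picked automatically on
    # supported x86-64 CPUs; enable_chunked_prefill=True can be added as well.
    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
              enable_prefix_caching=True)
    outputs = llm.generate(["What is OpenVINO?"],
                           SamplingParams(temperature=0.8, max_tokens=64))
    print(outputs[0].outputs[0].text)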

**Table of contents**:

- :ref:`Requirements <openvino_backend_requirements>`
- :ref:`Quick start using Dockerfile <openvino_backend_quick_start_dockerfile>`
- :ref:`Build from source <install_openvino_backend_from_source>`
- :ref:`Performance tips <openvino_backend_performance_tips>`
- :ref:`Limitations <openvino_backend_limitations>`

.. _openvino_backend_requirements:

Requirements
------------

* OS: Linux
* Instruction set architecture (ISA) requirement: at least AVX2.

.. _openvino_backend_quick_start_dockerfile:

Quick start using Dockerfile
----------------------------

.. code-block:: console

    $ docker build -f Dockerfile.openvino -t vllm-openvino-env .
    $ docker run -it --rm vllm-openvino-env

.. _install_openvino_backend_from_source:

Install from source
-------------------

- First, install Python. For example, on Ubuntu 22.04, you can run:

.. code-block:: console

    $ sudo apt-get update -y
    $ sudo apt-get install python3

- Second, install the prerequisites for the vLLM OpenVINO backend installation:

.. code-block:: console

    $ pip install --upgrade pip
    $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu

- Finally, install vLLM with OpenVINO backend:

.. code-block:: console

    $ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v .

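After installation you can run a quick sanity check; this is only a suggestion, relying on the ``+openvino`` version suffix added in ``setup.py``:

.. code-block:: python

    import vllm

    # An OpenVINO build reports a version such as "X.Y.Z+openvino".
    print(vllm.__version__)
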
.. _openvino_backend_performance_tips:

Performance tips
----------------

The vLLM OpenVINO backend uses the following environment variables to control its behavior:

- ``VLLM_OPENVINO_KVCACHE_SPACE`` specifies the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the user's memory management pattern.

- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` controls the KV cache precision. By default, FP16 / BF16 is used, depending on the platform.

- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` enables U8 weight compression during the model loading stage. By default, compression is turned off.

To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``); a Python sketch using these settings follows the configuration example below.

The best known OpenVINO configuration is:

.. code-block:: console

    $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
        python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256

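The same settings can be applied from Python for offline inference. This is only a sketch (the model and prompt are illustrative; exporting the variables in the shell before starting Python is equivalent):

.. code-block:: python

    import os

    # Illustrative values; see the environment variable descriptions above.
    os.environ["VLLM_OPENVINO_KVCACHE_SPACE"] = "100"            # 100 GB for the KV cache
    os.environ["VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"] = "u8"    # U8 KV cache precision
    os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"  # U8 weight compression

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
              enable_chunked_prefill=True,
              max_num_batched_tokens=256)
    outputs = llm.generate(["Hello, OpenVINO!"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)
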
.. _openvino_backend_limitations:

Limitations
-----------

- LoRA serving is not supported.

- Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in the vLLM OpenVINO integration.

- Tensor and pipeline parallelism are not currently enabled in the vLLM integration.

- Speculative sampling is not tested within the vLLM integration.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -63,6 +63,7 @@ Documentation

getting_started/installation
getting_started/amd-installation
getting_started/openvino-installation
getting_started/cpu-installation
getting_started/neuron-installation
getting_started/tpu-installation
9 changes: 9 additions & 0 deletions requirements-openvino.txt
@@ -0,0 +1,9 @@
# Common dependencies
-r requirements-common.txt

# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0.dev
optimum-intel[openvino] >= 1.17.2

triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
11 changes: 10 additions & 1 deletion setup.py
@@ -233,6 +233,10 @@ def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"


def _is_openvino() -> bool:
return VLLM_TARGET_DEVICE == "openvino"


def _is_xpu() -> bool:
return VLLM_TARGET_DEVICE == "xpu"

@@ -337,6 +341,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"+neuron{neuron_version_str}"
elif _is_openvino():
version += "+openvino"
elif _is_tpu():
version += "+tpu"
elif _is_cpu():
@@ -388,6 +394,8 @@ def _read_requirements(filename: str) -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
elif _is_openvino():
requirements = _read_requirements("requirements-openvino.txt")
elif _is_tpu():
requirements = _read_requirements("requirements-tpu.txt")
elif _is_cpu():
@@ -396,7 +404,8 @@ def _read_requirements(filename: str) -> List[str]:
requirements = _read_requirements("requirements-xpu.txt")
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.")
"Unsupported platform, please use CUDA, ROCm, Neuron, "
"OpenVINO, or CPU.")
return requirements


9 changes: 7 additions & 2 deletions tests/kernels/test_attention_selector.py
@@ -9,8 +9,8 @@


@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
@@ -28,6 +28,11 @@ def test_env(name: str, device: str, monkeypatch):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(8, 16, 8, None, torch.float16,
torch.float16, 16)
101 changes: 101 additions & 0 deletions vllm/attention/backends/openvino.py
@@ -0,0 +1,101 @@
from dataclasses import dataclass
from typing import List, Tuple

import openvino as ov
import torch

from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)


class OpenVINOAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
return "openvino"

@staticmethod
def get_impl_cls():
# OpenVINO implements PagedAttention as part of the Optimum
# exported model
raise NotImplementedError

@staticmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError

@staticmethod
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
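        # The leading dimension of size 2 accounts for the key and value caches;
        # each cache block holds [num_kv_heads, block_size, head_size] elements.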
return (2, num_blocks, num_kv_heads, block_size, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: ov.Tensor,
dst_kv_cache: ov.Tensor,
src_to_dst: torch.Tensor,
) -> None:
# OpenVINO currently supports only CPU, which does not require
# swap of KV cache blocks
raise NotImplementedError

@staticmethod
def copy_blocks(
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]],
src_to_dists: List[Tuple[int, int]],
) -> None:
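        # Copy KV cache blocks in place on the CPU: for every (src, dst) pair,
        # duplicate the source block in each key and value cache tensor.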
for src, dst in src_to_dists:
for key_cache, value_cache in kv_caches:
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]


@dataclass
class OpenVINOAttentionMetadata:
"""Metadata for OpenVINOAttentionBackend.
Basic terms used below:
- batch_size_in_sequences - total number of sequences to execute​
- prompt_lens – per sequence size number of scheduled tokens​
- batch_size_in_tokens = sum(prompt_lens)​
- max_context_len = max(context_lens)​
- max_num_blocks = div_up(max_context_len / BLOCK_SIZE)​
- num_blocks – total number of blocks in block_indices​
"""

    # Describes past KV cache size for each sequence within a batch
    # Shape: [batch_size_in_sequences]
    # Type: i32
    past_lens: torch.Tensor

    # Describes start indices of input / speculative tokens from
    # current sequences within a batch sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    subsequence_begins: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # indices along 0th dimension in key_cache and value_cache inputs
    # Shape: [num_blocks]
    # Type: i32
    block_indices: torch.Tensor

    # Describes block tables for each sequence within a batch -
    # for i-th element, it is an index in block_indices with the
    # first block belonging to i-th sequence
    # Shape: [batch_size_in_sequences + 1]
    # Type: i32
    block_indices_begins: torch.Tensor

    # Describes max context length
    # Shape: scalar
    # Type: i32
    max_context_len: torch.Tensor
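For illustration only, metadata for a hypothetical batch of two sequences (one with 5 cached and 3 new tokens, one with an empty cache and 4 new tokens, one KV block each) could be built as follows; real metadata is produced by the vLLM OpenVINO integration at runtime:

import torch

from vllm.attention.backends.openvino import OpenVINOAttentionBackend

# Hypothetical values purely for illustration of the field semantics above.
metadata = OpenVINOAttentionBackend.make_openvino_metadata(
    past_lens=torch.tensor([5, 0], dtype=torch.int32),
    subsequence_begins=torch.tensor([0, 3, 7], dtype=torch.int32),   # 3 and 4 new tokens
    block_indices=torch.tensor([0, 1], dtype=torch.int32),           # one block per sequence
    block_indices_begins=torch.tensor([0, 1, 2], dtype=torch.int32),
    max_context_len=torch.tensor(8, dtype=torch.int32),              # max(5 + 3, 0 + 4)
)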
12 changes: 11 additions & 1 deletion vllm/attention/selector.py
@@ -7,7 +7,7 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu

logger = init_logger(__name__)

@@ -17,6 +17,7 @@ class _Backend(enum.Enum):
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
@@ -64,6 +65,10 @@ def get_attn_backend(
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend
elif backend == _Backend.OPENVINO:
logger.info("Using OpenVINO Attention backend.")
from vllm.attention.backends.openvino import OpenVINOAttentionBackend
return OpenVINOAttentionBackend
elif backend == _Backend.IPEX:
assert is_xpu(), RuntimeError(
"IPEX attention backend is only used for the XPU device.")
@@ -113,6 +118,11 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

if is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO

if is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)