forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Hardware][Intel] OpenVINO vLLM backend (vllm-project#5379)
- Loading branch information
1 parent
cd3f6ee
commit af2e655
Showing
22 changed files
with
1,393 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# This script build the OpenVINO docker image and run the offline inference inside the container. | ||
# It serves a sanity check for compilation and basic model usage. | ||
set -ex | ||
|
||
# Try building the docker image | ||
docker build -t openvino-test -f Dockerfile.openvino . | ||
|
||
# Setup cleanup | ||
remove_docker_container() { docker rm -f openvino-test || true; } | ||
trap remove_docker_container EXIT | ||
remove_docker_container | ||
|
||
# Run the image and launch offline inference | ||
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used | ||
# to run the OpenAI compatible server. | ||
|
||
FROM ubuntu:22.04 AS dev | ||
|
||
RUN apt-get update -y && \ | ||
apt-get install -y python3-pip git | ||
WORKDIR /workspace | ||
|
||
# copy requirements | ||
COPY requirements-build.txt /workspace/vllm/ | ||
COPY requirements-common.txt /workspace/vllm/ | ||
COPY requirements-openvino.txt /workspace/vllm/ | ||
|
||
COPY vllm/ /workspace/vllm/vllm | ||
COPY setup.py /workspace/vllm/ | ||
|
||
# install build requirements | ||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt | ||
# build vLLM with OpenVINO backend | ||
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ | ||
|
||
COPY examples/ /workspace/vllm/examples | ||
COPY benchmarks/ /workspace/vllm/benchmarks | ||
|
||
CMD ["/bin/bash"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
.. _installation_openvino: | ||
|
||
Installation with OpenVINO | ||
========================== | ||
|
||
vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: | ||
|
||
- Prefix caching (``--enable-prefix-caching``) | ||
- Chunked prefill (``--enable-chunked-prefill``) | ||
|
||
**Table of contents**: | ||
|
||
- :ref:`Requirements <openvino_backend_requirements>` | ||
- :ref:`Quick start using Dockerfile <openvino_backend_quick_start_dockerfile>` | ||
- :ref:`Build from source <install_openvino_backend_from_source>` | ||
- :ref:`Performance tips <openvino_backend_performance_tips>` | ||
- :ref:`Limitations <openvino_backend_limitations>` | ||
|
||
.. _openvino_backend_requirements: | ||
|
||
Requirements | ||
------------ | ||
|
||
* OS: Linux | ||
* Instruction set architecture (ISA) requirement: at least AVX2. | ||
|
||
.. _openvino_backend_quick_start_dockerfile: | ||
|
||
Quick start using Dockerfile | ||
---------------------------- | ||
|
||
.. code-block:: console | ||
$ docker build -f Dockerfile.openvino -t vllm-openvino-env . | ||
$ docker run -it --rm vllm-openvino-env | ||
.. _install_openvino_backend_from_source: | ||
|
||
Install from source | ||
------------------- | ||
|
||
- First, install Python. For example, on Ubuntu 22.04, you can run: | ||
|
||
.. code-block:: console | ||
$ sudo apt-get update -y | ||
$ sudo apt-get install python3 | ||
- Second, install prerequisites vLLM OpenVINO backend installation: | ||
|
||
.. code-block:: console | ||
$ pip install --upgrade pip | ||
$ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu | ||
- Finally, install vLLM with OpenVINO backend: | ||
|
||
.. code-block:: console | ||
$ PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE=openvino python -m pip install -v . | ||
.. _openvino_backend_performance_tips: | ||
|
||
Performance tips | ||
---------------- | ||
|
||
vLLM OpenVINO backend uses the following environment variables to control behavior: | ||
|
||
- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. | ||
|
||
- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform. | ||
|
||
- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. | ||
|
||
To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``) | ||
|
||
OpenVINO best known configuration is: | ||
|
||
.. code-block:: console | ||
$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \ | ||
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256 | ||
.. _openvino_backend_limitations: | ||
|
||
Limitations | ||
----------- | ||
|
||
- LoRA serving is not supported. | ||
|
||
- Only LLM models are currently supported. LLaVa and encoder-decoder models are not currently enabled in vLLM OpenVINO integration. | ||
|
||
- Tensor and pipeline parallelism are not currently enabled in vLLM integration. | ||
|
||
- Speculative sampling is not tested within vLLM integration. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Common dependencies | ||
-r requirements-common.txt | ||
|
||
# OpenVINO dependencies | ||
torch >= 2.1.2 | ||
openvino ~= 2024.3.0.dev | ||
optimum-intel[openvino] >= 1.17.2 | ||
|
||
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from dataclasses import dataclass | ||
from typing import List, Tuple | ||
|
||
import openvino as ov | ||
import torch | ||
|
||
from vllm.attention.backends.abstract import (AttentionBackend, | ||
AttentionMetadata) | ||
|
||
|
||
class OpenVINOAttentionBackend(AttentionBackend): | ||
|
||
@staticmethod | ||
def get_name() -> str: | ||
return "openvino" | ||
|
||
@staticmethod | ||
def get_impl_cls(): | ||
# OpenVINO implements PagedAttention as part of the Optimum | ||
# exported model | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def make_metadata(*args, **kwargs) -> "AttentionMetadata": | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": | ||
return OpenVINOAttentionMetadata(*args, **kwargs) | ||
|
||
@staticmethod | ||
def get_kv_cache_shape( | ||
num_blocks: int, | ||
block_size: int, | ||
num_kv_heads: int, | ||
head_size: int, | ||
) -> Tuple[int, ...]: | ||
return (2, num_blocks, num_kv_heads, block_size, head_size) | ||
|
||
@staticmethod | ||
def swap_blocks( | ||
src_kv_cache: ov.Tensor, | ||
dst_kv_cache: ov.Tensor, | ||
src_to_dst: torch.Tensor, | ||
) -> None: | ||
# OpenVINO currently supports only CPU, which does not require | ||
# swap of KV cache blocks | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def copy_blocks( | ||
kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], | ||
src_to_dists: List[Tuple[int, int]], | ||
) -> None: | ||
for src, dst in src_to_dists: | ||
for key_cache, value_cache in kv_caches: | ||
key_cache.data[dst, :] = key_cache.data[src, :] | ||
value_cache.data[dst, :] = value_cache.data[src, :] | ||
|
||
|
||
@dataclass | ||
class OpenVINOAttentionMetadata: | ||
"""Metadata for OpenVINOAttentionBackend. | ||
Basic terms used below: | ||
- batch_size_in_sequences - total number of sequences to execute | ||
- prompt_lens – per sequence size number of scheduled tokens | ||
- batch_size_in_tokens = sum(prompt_lens) | ||
- max_context_len = max(context_lens) | ||
- max_num_blocks = div_up(max_context_len / BLOCK_SIZE) | ||
- num_blocks – total number of blocks in block_indices | ||
""" | ||
|
||
# Describes past KV cache size for each sequence within a batch | ||
# Shape: [batch_size_in_sequences] | ||
# Type: i32 | ||
past_lens: torch.Tensor | ||
|
||
# Describes start indices of input / speculative tokens from | ||
# current sequences within a batch sequence | ||
# Shape: [batch_size_in_sequences + 1] | ||
# Type: i32 | ||
subsequence_begins: torch.Tensor | ||
|
||
# Describes block tables for each sequence within a batch - | ||
# indices along 0th dimension in key_cache and value_cache inputs | ||
# Shape: [num_blocks] | ||
# Type: i32 | ||
block_indices: torch.Tensor | ||
|
||
# Describes block tables for each sequence within a batch - | ||
# for i-th element, it is an index in block_indices with the | ||
# first block belonging to i-th sequence | ||
# Shape: [batch_size_in_sequences + 1] | ||
# Type: i32 | ||
block_indices_begins: torch.Tensor | ||
|
||
# Describes max context length | ||
# Shape: scalar | ||
# Type: i32 | ||
max_context_len: torch.Tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.