Skip to content

Commit

Permalink
[Hardware][Intel] Optimize CPU backend and add more performance tips (v…
Browse files Browse the repository at this point in the history
…llm-project#4971)

Co-authored-by: Jianan Gu <jianan.gu@intel.com>
  • Loading branch information
bigPYJ1151 and jianan-gu authored Jun 13, 2024
1 parent e70a479 commit 5295b12
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 13 deletions.
8 changes: 6 additions & 2 deletions Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
FROM ubuntu:22.04 AS cpu-test-1

RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc

RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl

RUN pip install --upgrade pip \
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy

Expand All @@ -21,6 +25,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

WORKDIR /workspace/

RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

CMD ["/bin/bash"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ vLLM is flexible and easy to use with:
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs and AMD GPUs
- Support NVIDIA GPUs, AMD GPUs, and Intel CPUs
- (Experimental) Prefix caching support
- (Experimental) Multi-lora support

Expand Down
23 changes: 21 additions & 2 deletions docs/source/getting_started/cpu-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Table of contents:
#. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>`
#. :ref:`Intel Extension for PyTorch <ipex_guidance>`
#. :ref:`Performance tips <cpu_backend_performance_tips>`

.. _cpu_backend_requirements:
Expand All @@ -18,7 +19,7 @@ Requirements
------------

* OS: Linux
* Compiler: gcc/g++>=12.3.0 (recommended)
* Compiler: gcc/g++>=12.3.0 (optional, recommended)
* Instruction set architecture (ISA) requirement: AVX512 is required.

.. _cpu_backend_quick_start_dockerfile:
Expand All @@ -41,7 +42,7 @@ Quick start using Dockerfile
Build from source
-----------------

- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:
- First, install recommended compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run:

.. code-block:: console
Expand Down Expand Up @@ -70,13 +71,31 @@ Build from source

- If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.

.. _ipex_guidance:

Intel Extension for PyTorch
---------------------------

- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.

- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed.

.. _cpu_backend_performance_tips:

Performance tips
-----------------

- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.

- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:

.. code-block:: console
$ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library
$ find / -name *libtcmalloc* # find the dynamic link library path
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
$ python examples/offline_inference.py # run vLLM
- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.

- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
Expand Down
2 changes: 1 addition & 1 deletion requirements-cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
-r requirements-common.txt

# Dependencies for x86_64 CPUs
torch == 2.3.0+cpu
torch == 2.3.1+cpu
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
23 changes: 16 additions & 7 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,16 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu

if is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention


class TorchSDPABackend(AttentionBackend):
Expand Down Expand Up @@ -197,13 +205,14 @@ def forward(
attn_metadata.attn_bias):
end = start + seq_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
key[:, start:end, :],
value[:, start:end, :],
query[None, :, start:end, :],
key[None, :, start:end, :],
value[None, :, start:end, :],
attn_mask=mask,
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out
start = end
else:
Expand Down Expand Up @@ -248,7 +257,7 @@ def _make_alibi_bias(

num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None])
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty(
(1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
Expand Down
120 changes: 120 additions & 0 deletions vllm/attention/ops/ipex_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Dict, List, Optional, Tuple

import intel_extension_for_pytorch.llm.modules as ipex_modules
import torch

from vllm import _custom_ops as ops


class PagedAttention:

@staticmethod
def get_supported_head_sizes() -> List[int]:
return [64, 80, 96, 112, 128, 256]

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)

@staticmethod
def split_kv_cache(
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
*args,
) -> Tuple[torch.Tensor, torch.Tensor]:
num_blocks = kv_cache.shape[1]

key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size)
return key_cache, value_cache

@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
*args,
) -> None:
ipex_modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache,
slot_mapping.flatten().int())

@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
kv_scale: float,
*args,
) -> torch.Tensor:
output = torch.empty_like(query)
block_size = value_cache.shape[2]
head_mapping = torch.arange(
0,
num_kv_heads,
device="cpu",
dtype=torch.int32,
).view(num_kv_heads,
1).repeat_interleave(query.size(1) // num_kv_heads).flatten()
ipex_modules.PagedAttention.single_query_cached_kv_attention(
output, query.contiguous(), key_cache, value_cache, head_mapping,
scale, block_tables, context_lens, block_size, max_context_len,
alibi_slopes)

return output

@staticmethod
def forward_prefix(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
prompt_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_subquery_len: int,
alibi_slopes: Optional[torch.Tensor],
*args,
) -> torch.Tensor:
raise NotImplementedError

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
*args,
) -> None:
raise NotImplementedError

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)

0 comments on commit 5295b12

Please sign in to comment.