[Hardware][TPU] Optimize KV cache swapping (vllm-project#5878)
Signed-off-by: Alvant <alvasian@yandex.ru>
WoosukKwon authored and Alvant committed Oct 26, 2024
1 parent 44d625a commit 3cb5dfd
Showing 2 changed files with 36 additions and 22 deletions.
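
In short, the commit stops routing TPU swaps through the attention backend's compiled swap_blocks (which took tuple-typed caches) and instead has the worker build the block-index tensors once per swap direction, gather and transfer the blocks itself, and reserve compilation for a small in-place insert helper whose destination buffers can be donated. A rough sketch of the index-tensor step, assuming blocks_to_swap_in is a list of (src_block, dst_block) pairs (the exact input type is not visible in this diff):

import torch_xla.core.xla_model as xm

# _make_src_to_dst is the existing helper whose tail appears in the
# tpu_worker.py diff below; the pairs here are made-up illustrative values.
from vllm.worker.tpu_worker import _make_src_to_dst

device = xm.xla_device()
blocks_to_swap_in = [(0, 3), (1, 7)]  # e.g. CPU block 0 -> TPU block 3
src_indices, dst_indices = _make_src_to_dst(blocks_to_swap_in, "cpu", device)
# src_indices == tensor([0, 1]) on the host, dst_indices == tensor([3, 7])
# on the TPU, both int64, so each side can index its own cache directly.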
16 changes: 4 additions & 12 deletions vllm/attention/backends/pallas.py
@@ -28,21 +28,13 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
-    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
     ) -> None:
-        src_k_cache, src_v_cache = src_kv_cache
-        dst_k_cache, dst_v_cache = dst_kv_cache
-        src_indices, dst_indices = src_to_dst
-        device = dst_k_cache.device
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
-        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
     @torch.compile(backend="openxla")
     @staticmethod
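
A detail worth keeping in mind when reading the worker changes below: get_kv_cache_shape (unchanged above) puts the block dimension second, (num_kv_heads, num_blocks, block_size, head_size), which is why every cache is indexed as cache[:, indices]. A self-contained illustration with made-up sizes:

import torch

# (num_kv_heads, num_blocks, block_size, head_size) -- sizes are made up.
k_cache = torch.zeros(8, 1024, 16, 128)
block_ids = torch.tensor([3, 7], dtype=torch.int64)

picked = k_cache[:, block_ids]  # selects whole blocks, across all heads
print(picked.shape)             # torch.Size([8, 2, 16, 128])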
42 changes: 32 additions & 10 deletions vllm/worker/tpu_worker.py
@@ -3,6 +3,7 @@
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
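
The new import exists only for its side effect; judging from the # noqa: F401 marker, loading torch_xla.experimental.dynamo_set_buffer_donor registers the torch.ops.xla.dynamo_set_buffer_donor_ op that the new _insert_kv helper calls:

import torch

# Side-effect import: registers the XLA buffer-donor op for dynamo.
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

torch.ops.xla.dynamo_set_buffer_donor_  # resolvable after the import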
@@ -152,8 +153,8 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
@@ -227,18 +228,25 @@ def cache_swap(
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
 
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
 
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
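
For context on the donor calls: marking tpu_k_cache and tpu_v_cache as buffer donors lets XLA alias their storage with the compiled graph's outputs, so the insert updates the caches in place rather than materializing a second full-size copy. A minimal usage sketch, assuming a TPU/XLA device is available; shapes and index values are made up:

import torch
import torch_xla.core.xla_model as xm

from vllm.worker.tpu_worker import _insert_kv  # as added in this commit

device = xm.xla_device()
# Made-up sizes: 8 KV heads, 64 blocks of 16 tokens, head dim 128.
tpu_k_cache = torch.zeros(8, 64, 16, 128, device=device)
tpu_v_cache = torch.zeros(8, 64, 16, 128, device=device)
cpu_k_cache = torch.randn(8, 64, 16, 128)
cpu_v_cache = torch.randn(8, 64, 16, 128)

src_indices = torch.tensor([0, 1], dtype=torch.int64)  # host side
dst_indices = torch.tensor([3, 7], dtype=torch.int64, device=device)

# Host-side gather + transfer, then one compiled in-place insert.
k = cpu_k_cache[:, src_indices].to(device)
v = cpu_v_cache[:, src_indices].to(device)
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
xm.mark_step()  # flush the pending XLA computation

The swap-out direction needs no compiled helper: the cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu() assignments in cache_swap run on the host once the blocks have been transferred.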
