From acd5beefd2b67bf73cb4430842c2ffc22c667e6d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Thu, 27 Jun 2024 21:12:13 -0700
Subject: [PATCH] [Hardware][TPU] Optimize KV cache swapping (#5878)

---
 vllm/attention/backends/pallas.py | 16 +++---------
 vllm/worker/tpu_worker.py         | 42 +++++++++++++++++++++++--------
 2 files changed, 36 insertions(+), 22 deletions(-)

diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 5dec11e2eede7..22cb1a1bd1fd3 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -28,21 +28,13 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
-    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
     ) -> None:
-        src_k_cache, src_v_cache = src_kv_cache
-        dst_k_cache, dst_v_cache = dst_kv_cache
-        src_indices, dst_indices = src_to_dst
-        device = dst_k_cache.device
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
-        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
     @torch.compile(backend="openxla")
     @staticmethod
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 28f460c31aa9b..37d810e8392a9 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -3,6 +3,7 @@
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
@@ -152,8 +153,8 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
@@ -227,18 +228,25 @@ def cache_swap(
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
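Note on the pattern above: the swap-in path gathers blocks on the CPU with plain indexing, moves them to the TPU, and then writes them into the TPU caches through `_insert_kv`, which is compiled with the openxla backend and marks both destination tensors as buffer donors so XLA can reuse their storage for the output instead of materializing a fresh copy of each cache. Below is a minimal standalone sketch of that donate-and-compile pattern; the function name `insert_blocks`, the tensor shapes, and the index values are illustrative only, while the imports and `torch.ops.xla.dynamo_set_buffer_donor_` are the ones the patch itself uses (a TPU/XLA runtime is assumed to be available):

import torch
import torch_xla.core.xla_model as xm
# Imported for its side effect: registering torch.ops.xla.dynamo_set_buffer_donor_.
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

@torch.compile(backend="openxla")
def insert_blocks(src: torch.Tensor, indices: torch.Tensor,
                  dst: torch.Tensor) -> None:
    # Donating dst lets XLA write the result into dst's existing buffer,
    # avoiding a second device-side allocation of the whole cache.
    torch.ops.xla.dynamo_set_buffer_donor_(dst, True)
    dst[:, indices] = src

device = xm.xla_device()
# Illustrative stand-ins for a (num_kv_heads, num_blocks, block_size, head_size) cache.
dst = torch.zeros(2, 8, 16, 64, device=device)
src = torch.ones(2, 3, 16, 64, device=device)
indices = torch.tensor([1, 4, 6], device=device, dtype=torch.int64)
insert_blocks(src, indices, dst)

The swap-out direction needs no such helper: the destination already lives in host memory, so ordinary indexing plus `.cpu()` suffices, and `swap_blocks` in the Pallas backend now raises because the worker performs both directions itself.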