
Commit 5f1316e

WoosukKwon authored and Robert Shaw committed

[Hardware][TPU] Optimize KV cache swapping (vllm-project#5878)

1 parent 4d5e0b9 commit 5f1316e

File tree: 2 files changed, +36 −22 lines

vllm/attention/backends/pallas.py

Lines changed: 4 additions & 12 deletions

@@ -28,21 +28,13 @@ def get_kv_cache_shape(
     ) -> Tuple[int, ...]:
         return (num_kv_heads, num_blocks, block_size, head_size)
 
-    @torch.compile(backend="openxla")
     @staticmethod
     def swap_blocks(
-        src_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        dst_kv_cache: Tuple[torch.Tensor, torch.Tensor],
-        src_to_dst: Tuple[torch.Tensor, torch.Tensor],
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
    ) -> None:
-        src_k_cache, src_v_cache = src_kv_cache
-        dst_k_cache, dst_v_cache = dst_kv_cache
-        src_indices, dst_indices = src_to_dst
-        device = dst_k_cache.device
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
-        torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-        dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
-        dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
+        raise RuntimeError("swap_blocks is not used for the TPU backend.")
 
     @torch.compile(backend="openxla")
     @staticmethod
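With this change, swapping is driven entirely by the TPU worker (see vllm/worker/tpu_worker.py below), and the backend method is reduced to a guard that fails loudly if anything still routes through it. A minimal sketch of the resulting contract, assuming the enclosing class is PallasAttentionBackend and using a hypothetical stray caller:

import torch
from vllm.attention.backends.pallas import PallasAttentionBackend

# Hypothetical caller: any code path that still reaches the backend-level
# swap now gets an immediate, descriptive error instead of a silent no-op.
try:
    PallasAttentionBackend.swap_blocks(
        torch.empty(0), torch.empty(0), torch.empty(0, dtype=torch.int64))
except RuntimeError as err:
    print(err)  # swap_blocks is not used for the TPU backend.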

vllm/worker/tpu_worker.py

Lines changed: 32 additions & 10 deletions

@@ -3,6 +3,7 @@
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
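The new import is used only for its side effect (hence the noqa: F401): loading torch_xla.experimental.dynamo_set_buffer_donor makes the torch.ops.xla.dynamo_set_buffer_donor_ op available to the worker's compiled _insert_kv below. A minimal sketch of the donation pattern, assuming torch_xla is installed and an XLA device is available; the function name _overwrite is invented for illustration:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

@torch.compile(backend="openxla")
def _overwrite(buf: torch.Tensor, value: torch.Tensor) -> None:
    # Marking buf as a donor lets XLA alias its storage for the graph's
    # output instead of allocating a second buffer and copying back.
    torch.ops.xla.dynamo_set_buffer_donor_(buf, True)
    buf[:] = value

device = xm.xla_device()
buf = torch.zeros(4, device=device)
_overwrite(buf, torch.ones(4, device=device))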
@@ -152,8 +153,8 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
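The annotations make the cache layout explicit: one (key_cache, value_cache) tensor pair per decoder layer, each shaped by get_kv_cache_shape. A tiny illustration of the structure, with invented layer count and sizes:

from typing import List, Tuple
import torch

num_layers = 2  # invented for illustration
# (num_kv_heads, num_blocks, block_size, head_size) per get_kv_cache_shape
shape = (2, 8, 16, 4)
cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [
    (torch.zeros(shape), torch.zeros(shape)) for _ in range(num_layers)
]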
@@ -227,18 +228,25 @@ def cache_swap(
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
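The rewritten swap-in path gathers all requested blocks from the host cache with one fancy-indexing read, moves them in a single host-to-device transfer per layer, and scatters them into the TPU cache in one compiled call (_insert_kv), rather than running a compiled tuple-typed swap_blocks per layer. A self-contained, CPU-only sketch of the gather/scatter; the two-block mapping and sizes are invented, and a plain tensor stands in for the TPU-side cache:

import torch

# KV cache layout per get_kv_cache_shape: (heads, blocks, block_size, head_size)
cpu_k_cache = torch.randn(2, 8, 16, 4)
tpu_k_cache = torch.zeros(2, 8, 16, 4)  # stand-in for the device tensor

# Swap CPU blocks 0 and 1 into device slots 2 and 3 (invented mapping; in the
# real code these index tensors come from _make_src_to_dst).
src_indices = torch.tensor([0, 1], dtype=torch.int64)
dst_indices = torch.tensor([2, 3], dtype=torch.int64)

k = cpu_k_cache[:, src_indices]   # one gather; the real code adds .to(device)
tpu_k_cache[:, dst_indices] = k   # one scatter; done by _insert_kv on the TPU

assert torch.equal(tpu_k_cache[:, 2], cpu_k_cache[:, 0])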
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
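Compiling only _insert_kv with the openxla backend keeps the traced graph small: the gather and the host-to-device copy happen eagerly in cache_swap, and only the scatter is staged through XLA. Donating tpu_k_cache and tpu_v_cache tells XLA it may write the updated caches back into the existing buffers rather than allocating fresh ones, which matters because under functionalization the in-place index assignment would otherwise materialize a full copy of each cache tensor. Semantically the function is just a scatter; a CPU-only eager equivalent (no XLA, no donation, invented sizes) for intuition:

import torch

def insert_kv_eager(k, v, indices, k_cache, v_cache):
    # Same effect as _insert_kv, minus compilation and buffer donation.
    k_cache[:, indices] = k
    v_cache[:, indices] = v

k_cache = torch.zeros(2, 8, 16, 4)
v_cache = torch.zeros(2, 8, 16, 4)
k = torch.randn(2, 2, 16, 4)  # two blocks of keys
v = torch.randn(2, 2, 16, 4)  # two blocks of values
insert_kv_eager(k, v, torch.tensor([2, 3]), k_cache, v_cache)
assert torch.equal(k_cache[:, 2:4], k)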
