
Commit f5c8628

[Bugfix][TPU] Fix CPU cache allocation (#5869)
1 parent cbc53b6

2 files changed (+8, -5 lines)

vllm/attention/backends/pallas.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -37,11 +37,10 @@ def swap_blocks(
     ) -> None:
         src_k_cache, src_v_cache = src_kv_cache
         dst_k_cache, dst_v_cache = dst_kv_cache
+        src_indices, dst_indices = src_to_dst
+        device = dst_k_cache.device
         torch.ops.xla.dynamo_set_buffer_donor_(dst_k_cache, True)
         torch.ops.xla.dynamo_set_buffer_donor_(dst_v_cache, True)
-
-        device = dst_k_cache.device
-        src_indices, dst_indices = src_to_dst
         dst_k_cache[:, dst_indices] = src_k_cache[:, src_indices].to(device)
         dst_v_cache[:, dst_indices] = src_v_cache[:, src_indices].to(device)
```

vllm/worker/tpu_worker.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -156,14 +156,18 @@ def initialize_cache(
         self.tpu_cache = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
+        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
+            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
         for _ in range(num_layers):
             tpu_k_cache = torch.zeros(tpu_cache_shape,
                                       dtype=dtype,
                                       device=self.device)
             tpu_v_cache = torch.zeros_like(tpu_k_cache)
             self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
-            cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")
-            cpu_v_cache = torch.zeros_like(tpu_v_cache, device="cpu")
+            cpu_k_cache = torch.zeros(cpu_cache_shape,
+                                      dtype=dtype,
+                                      device="cpu")
+            cpu_v_cache = torch.zeros_like(cpu_k_cache)
             self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
         self._warmup_model()
```
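The bug this fixes: `torch.zeros_like(tpu_k_cache, device="cpu")` copies the TPU cache's shape, which is sized by `num_gpu_blocks`, so the CPU swap cache was allocated with the wrong number of blocks; it should be sized by `num_cpu_blocks`. A sketch with hypothetical block counts and a stand-in `kv_cache_shape` helper (the real shape comes from `attn_backend.get_kv_cache_shape`):

```python
import torch

# Stand-in for get_kv_cache_shape; the actual layout is backend-defined.
def kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size):
    return (num_kv_heads, num_blocks, block_size, head_size)

# Hypothetical sizes: the scheduler hands out far fewer CPU (swap) blocks
# than device blocks.
num_gpu_blocks, num_cpu_blocks = 512, 128
block_size, num_kv_heads, head_size = 16, 8, 64

tpu_k_cache = torch.zeros(
    kv_cache_shape(num_gpu_blocks, block_size, num_kv_heads, head_size))

# Before the fix: zeros_like inherits the TPU cache's shape, so the CPU
# cache is sized by num_gpu_blocks instead of num_cpu_blocks.
buggy_cpu_k_cache = torch.zeros_like(tpu_k_cache, device="cpu")

# After the fix: the CPU cache gets its own shape from num_cpu_blocks.
fixed_cpu_k_cache = torch.zeros(
    kv_cache_shape(num_cpu_blocks, block_size, num_kv_heads, head_size))

print(buggy_cpu_k_cache.shape)  # torch.Size([8, 512, 16, 64])
print(fixed_cpu_k_cache.shape)  # torch.Size([8, 128, 16, 64])
```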
