@@ -3,6 +3,7 @@
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
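Note: the new import appears to be side-effect-only; it registers the `torch.ops.xla.dynamo_set_buffer_donor_` op that the new `_insert_kv` helper calls below (hence the `# noqa: F401`). A quick sanity check, assuming `torch_xla` is installed:

```python
import torch
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

# The import should have registered the donor op under torch.ops.xla.
assert hasattr(torch.ops.xla, "dynamo_set_buffer_donor_")
```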
@@ -152,8 +153,8 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
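Each cache entry is now an explicit `(key, value)` tensor pair. The exact layout comes from `attn_backend.get_kv_cache_shape(...)`, but the `[:, indices]` indexing used later implies dim 1 is the block axis. A minimal sketch of the allocation that presumably follows, with made-up shapes (the 4-D layout and bfloat16 dtype are assumptions for illustration only):

```python
from typing import List, Tuple

import torch

# Illustrative stand-ins for the values computed above; the real shape comes
# from attn_backend.get_kv_cache_shape(num_blocks, block_size, num_kv_heads,
# head_size). Dim 1 must be the block axis for the swap indexing to work.
num_layers, num_blocks = 2, 16
num_kv_heads, block_size, head_size = 4, 16, 64
cache_shape = (num_kv_heads, num_blocks, block_size, head_size)

cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(num_layers):
    k_cache = torch.zeros(cache_shape, dtype=torch.bfloat16, device="cpu")
    v_cache = torch.zeros(cache_shape, dtype=torch.bfloat16, device="cpu")
    cpu_cache.append((k_cache, v_cache))
```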
@@ -227,18 +228,25 @@ def cache_swap(
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
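With this change, `_make_src_to_dst` returns a `(src_indices, dst_indices)` pair instead of a single mapping tensor. Only its tail is visible in the next hunk, so the following is a hedged reconstruction of the full helper, assuming the block mapping arrives as a `{src_block: dst_block}` dict (the actual parameter name and type are not shown in this diff):

```python
from typing import Dict, Tuple

import torch


def _make_src_to_dst(mapping: Dict[int, int], src_device,
                     dst_device) -> Tuple[torch.Tensor, torch.Tensor]:
    # Split the {src_block: dst_block} mapping into two parallel int64 index
    # tensors, each placed on the device where it will be used for indexing.
    src_indices = torch.tensor(list(mapping.keys()),
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(list(mapping.values()),
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices
```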
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
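For readers unfamiliar with buffer donation: marking the destination tensors as donated inputs tells the XLA compiler that the caller no longer needs their old contents, so the compiled scatter can reuse those buffers in place rather than allocating a second copy of the (potentially large) KV cache. A standalone sketch of the same pattern, assuming an XLA device is available (names and shapes here are made up for illustration):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401


@torch.compile(backend="openxla")
def insert_blocks(src: torch.Tensor, indices: torch.Tensor,
                  cache: torch.Tensor) -> None:
    # Donate `cache` so the compiled program may alias its buffer for the
    # in-place update instead of materializing a fresh output tensor.
    torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
    cache[:, indices] = src


device = xm.xla_device()
cache = torch.zeros(4, 16, 8, device=device)  # (heads, blocks, head_size)
src = torch.ones(4, 2, 8, device=device)      # two blocks to insert
indices = torch.tensor([3, 7], device=device, dtype=torch.int64)
insert_blocks(src, indices, cache)
xm.mark_step()  # flush the pending XLA computation
```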