@@ -3,6 +3,7 @@
 
 import torch
 import torch_xla.core.xla_model as xm
+import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
 import torch_xla.runtime as xr
 
 import vllm.envs as envs
@@ -152,8 +153,8 @@ def initialize_cache(
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
 
-        self.cpu_cache = []
-        self.tpu_cache = []
+        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
+        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
         tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
             num_gpu_blocks, self.block_size, num_kv_heads, head_size)
         cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
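
Note: the typed cache lists above hold one (key, value) tensor pair per
attention layer. A minimal sketch of how such a list is filled, using
placeholder values for num_layers, the cache shape, and the device (the
real values come from the model config and the attention backend):

import torch
from typing import List, Tuple

num_layers = 2                # placeholder
cache_shape = (2, 8, 4, 16)   # placeholder KV-cache shape
device = "cpu"                # would be the XLA device on a real TPU

tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
for _ in range(num_layers):
    k = torch.zeros(cache_shape, device=device)
    v = torch.zeros(cache_shape, device=device)
    tpu_cache.append((k, v))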
@@ -227,18 +228,25 @@ def cache_swap(
 
         if blocks_to_swap_in:
             # Swap from CPU to TPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_in, "cpu",
-                                          self.device)
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_in, "cpu", self.device)
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.cpu_cache[i], self.tpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                k = cpu_k_cache[:, src_indices].to(self.device)
+                v = cpu_v_cache[:, src_indices].to(self.device)
+                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
+
         if blocks_to_swap_out:
             # Swap from TPU to CPU.
-            src_to_dst = _make_src_to_dst(blocks_to_swap_out, self.device,
-                                          "cpu")
+            src_indices, dst_indices = _make_src_to_dst(
+                blocks_to_swap_out, self.device, "cpu")
             for i in range(num_layers):
-                attn_backend.swap_blocks(self.tpu_cache[i], self.cpu_cache[i],
-                                         src_to_dst)
+                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
+                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
+                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
+                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()
+
         if blocks_to_copy:
             src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                           self.device)
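
Note: _make_src_to_dst turns the scheduler's (src_block, dst_block) pairs
into two int64 index tensors, and each swap then moves whole blocks along
the block axis (dim 1) of the per-layer caches with a gather/scatter. A
minimal CPU-only sketch of that pattern, with illustrative shapes rather
than the backend's exact KV-cache layout:

import torch

num_kv_heads, num_blocks, block_size, head_size = 2, 8, 4, 16
cpu_k = torch.randn(num_kv_heads, num_blocks, block_size, head_size)
dev_k = torch.zeros_like(cpu_k)  # stands in for the TPU-resident cache

mapping = [(0, 3), (1, 5)]       # (src_block, dst_block) pairs
src = torch.tensor([s for s, _ in mapping], dtype=torch.int64)
dst = torch.tensor([d for _, d in mapping], dtype=torch.int64)

# Gather the source blocks along dim 1, scatter into the destination.
dev_k[:, dst] = cpu_k[:, src]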
@@ -267,3 +275,17 @@ def _make_src_to_dst(
                                device=dst_device,
                                dtype=torch.int64)
     return src_indices, dst_indices
+
+
+@torch.compile(backend="openxla")
+def _insert_kv(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    indices: torch.Tensor,
+    tpu_k_cache: torch.Tensor,
+    tpu_v_cache: torch.Tensor,
+) -> None:
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
+    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
+    tpu_k_cache[:, indices] = k
+    tpu_v_cache[:, indices] = v
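
Note: marking the TPU cache tensors as buffer donors lets XLA alias the
input and output buffers of the compiled update, so the index assignment
in _insert_kv happens in place instead of materializing a second full-size
copy of the cache. A standalone sketch of the same pattern (requires a TPU
runtime; the function name update_ and all sizes are placeholders):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401

@torch.compile(backend="openxla")
def update_(cache: torch.Tensor, idx: torch.Tensor, new: torch.Tensor) -> None:
    # Donate the cache buffer so XLA may reuse it for the output.
    torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
    cache[:, idx] = new

device = xm.xla_device()
cache = torch.zeros(2, 8, 4, 16, device=device)
idx = torch.tensor([3, 5], dtype=torch.int64, device=device)
new = torch.randn(2, 2, 4, 16, device=device)
update_(cache, idx, new)
xm.mark_step()  # flush the pending XLA computation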