
[Hardware][intel GPU] bump up ipex version to 2.3 #8365

Merged
merged 10 commits on Sep 13, 2024
fix copy_blocks
jikunshang committed Sep 11, 2024
commit 3e0a8f45e412bffbee29b472467ed54eb3c70890
vllm/attention/backends/ipex_attn.py (6 additions, 2 deletions)
@@ -49,14 +49,18 @@ def swap_blocks(
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        from vllm._ipex_ops import ipex_ops as ops
+        ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        from vllm._ipex_ops import ipex_ops as ops
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass
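For context on the copy_blocks fix above: the diff shows that each entry in kv_caches stacks the key and value caches together, with kv_cache[0] holding keys and kv_cache[1] holding values, and that ipex_ops.copy_blocks takes the two lists separately rather than the combined tensors that PagedAttention.copy_blocks accepted. Below is a minimal pure-PyTorch sketch of the copy semantics such a kernel is expected to provide; the function name copy_blocks_reference is hypothetical, and it assumes src_to_dists is a [num_pairs, 2] tensor of (src_block, dst_block) indices with the block id as the leading dimension of each cache.

    from typing import List

    import torch


    def copy_blocks_reference(key_caches: List[torch.Tensor],
                              value_caches: List[torch.Tensor],
                              src_to_dists: torch.Tensor) -> None:
        """Illustrative sketch only, not the actual IPEX kernel.

        Assumes each cache is indexed by block id along dim 0 and that
        src_to_dists holds (src, dst) block-id pairs, one pair per row.
        """
        for src, dst in src_to_dists.tolist():
            # Copy the block in every layer's key and value cache.
            for key_cache, value_cache in zip(key_caches, value_caches):
                key_cache[dst].copy_(key_cache[src])
                value_cache[dst].copy_(value_cache[src])

Under these assumptions, the two list comprehensions in the diff are just splitting the stacked per-layer caches into the key and value lists this signature expects; swap_blocks needs no such unpacking because it already operates on whole kv_cache tensors.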