|
@@ -55,10 +55,13 @@
 
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
 
@@ -105,10 +108,10 @@ def multihead_attention(
         q,
         k,
         v,
-        q_cu_seqlens,
-        k_cu_seqlens,
-        max_seqlen_q,
-        max_seqlen_k,
+        cu_seqlens_q=q_cu_seqlens,
+        cu_seqlens_k=k_cu_seqlens,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
         causal=False,
     )
     attn_out = attn_out.flatten(start_dim=-2)
@@ -290,7 +293,12 @@ class Rope2DPosEmb(nn.Module):
     """
 
     def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+        self,
+        dim: int,
+        max_height: int,
+        max_width: int,
+        theta_base=10000,
+        device=current_platform.device_type,
     ):
         super().__init__()
        self.dim = dim
@@ -436,7 +444,7 @@ def __init__(
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
         self.attn_implementation = attn_implementation
         # use fa2 in vllm by default
-        if is_flash_attn_2_available():
+        if is_flash_attn_2_available() or current_platform.is_xpu():
             self.attn_implementation = "flash_attention_2"
 
         self.norm0 = nn.LayerNorm(hidden_dim)
|
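For context, a minimal illustrative sketch (not part of the change) of the varlen-attention metadata the `multihead_attention` hunk passes, and of why it switches to keyword arguments: naming `cu_seqlens_q`/`cu_seqlens_k` and `max_seqlen_q`/`max_seqlen_k` keeps the call valid whether `flash_attn_varlen_func` comes from the upstream `flash_attn` package or from `vllm.attention.utils.fa_utils` on XPU. The sequence lengths below are made-up placeholders, not values from `moonvit.py`.

```python
# Illustrative only: how cu_seqlens / max_seqlen for varlen attention are
# typically derived when variable-length sequences are packed into one tensor.
import torch

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)   # made-up per-sequence token counts
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), seqlens.cumsum(0, dtype=torch.int32)]
)                                                       # tensor([0, 3, 8, 10])
max_seqlen = int(seqlens.max())                         # 5

# With the keyword form used in the diff, the same call reads identically for
# either backend:
#   attn_out = flash_attn_varlen_func(
#       q, k, v,
#       cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#       max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
#       causal=False,
#   )
```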