
Commit e196d22

support Kimi-VL-A3B-thinking on xpu (vllm-project#11)
Signed-off-by: Yan Ma <yan.ma@intel.com>
1 parent 37cf720 commit e196d22

File tree

1 file changed, 14 insertions(+), 6 deletions(-)


vllm/model_executor/models/moonvit.py

Lines changed: 14 additions & 6 deletions
@@ -55,10 +55,13 @@
 
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.models.utils import maybe_prefix
+from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.moonvit import MoonViTConfig
 
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_varlen_func
+elif current_platform.is_xpu():
+    from vllm.attention.utils.fa_utils import flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
 
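Context for the import change: when the upstream flash_attn package is not installed but the process is running on an Intel XPU, MoonViT now picks up vLLM's own varlen flash-attention wrapper instead of leaving the symbol as None. A minimal probe sketch of that selection order (the labels are mine, not vLLM output; it assumes is_flash_attn_2_available is the Hugging Face transformers helper that moonvit.py imports):

# Sketch: report which varlen flash-attention path MoonViT would take
# after this change. Strings below are illustrative, not vLLM log output.
from transformers.utils import is_flash_attn_2_available
from vllm.platforms import current_platform

if is_flash_attn_2_available():
    path = "flash_attn.flash_attn_varlen_func (upstream CUDA kernel)"
elif current_platform.is_xpu():
    path = "vllm.attention.utils.fa_utils.flash_attn_varlen_func (XPU wrapper)"
else:
    path = "None -> MoonViT keeps whatever attn_implementation it was given"

print("varlen flash attention:", path)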

@@ -105,10 +108,10 @@ def multihead_attention(
         q,
         k,
         v,
-        q_cu_seqlens,
-        k_cu_seqlens,
-        max_seqlen_q,
-        max_seqlen_k,
+        cu_seqlens_q=q_cu_seqlens,
+        cu_seqlens_k=k_cu_seqlens,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
         causal=False,
     )
     attn_out = attn_out.flatten(start_dim=-2)
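Switching to keyword arguments keeps this call unambiguous across the upstream flash_attn signature and the XPU wrapper, in case their positional parameters are not ordered identically. The cu_seqlens tensors are the standard varlen packing format; a small self-contained sketch of how such tensors are typically built (the helper name and values are illustrative, not from moonvit.py):

import torch

def build_cu_seqlens(seqlens: list[int]) -> torch.Tensor:
    # Varlen flash attention expects int32 cumulative lengths with a leading
    # zero: packed sequences of 64, 96 and 32 tokens give [0, 64, 160, 192].
    return torch.cumsum(torch.tensor([0] + seqlens, dtype=torch.int32),
                        dim=0, dtype=torch.int32)

seqlens = [64, 96, 32]                  # e.g. tokens per packed image
cu_seqlens = build_cu_seqlens(seqlens)  # tensor([0, 64, 160, 192], dtype=int32)
max_seqlen = max(seqlens)               # 96

# The call in the hunk then reads, with q/k/v packed as
# (total_tokens, num_heads, head_dim) and q_cu_seqlens == k_cu_seqlens
# for self-attention:
#   attn_out = flash_attn_varlen_func(
#       q, k, v,
#       cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
#       max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
#       causal=False,
#   )
print(cu_seqlens.tolist(), max_seqlen)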
@@ -290,7 +293,12 @@ class Rope2DPosEmb(nn.Module):
     """
 
     def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
+        self,
+        dim: int,
+        max_height: int,
+        max_width: int,
+        theta_base=10000,
+        device=current_platform.device_type,
     ):
         super().__init__()
         self.dim = dim
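The Rope2DPosEmb default device was previously hard-coded to "cuda", which fails on XPU; current_platform.device_type resolves to the active backend's device string instead. A minimal sketch of the same pattern outside moonvit.py (the class below is illustrative, not vLLM code):

import torch
from vllm.platforms import current_platform  # device_type: "cuda", "xpu", "cpu", ...

class InvFreqTable(torch.nn.Module):
    """Illustrative only: allocate rotary inverse frequencies on whatever
    device the running platform reports instead of assuming CUDA."""

    def __init__(self, dim: int, theta_base: float = 10000.0,
                 device: str = current_platform.device_type):
        super().__init__()
        inv_freq = 1.0 / (theta_base ** (
            torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)

table = InvFreqTable(dim=64)
print(table.inv_freq.device)  # cuda / xpu / cpu, depending on the platform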
@@ -436,7 +444,7 @@ def __init__(
         self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
         self.attn_implementation = attn_implementation
         # use fa2 in vllm by default
-        if is_flash_attn_2_available():
+        if is_flash_attn_2_available() or current_platform.is_xpu():
             self.attn_implementation = "flash_attention_2"
 
         self.norm0 = nn.LayerNorm(hidden_dim)
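With all four hunks in place, the MoonViT vision tower takes the flash_attention_2 path on Intel XPU as well as on CUDA. A hedged end-to-end sketch of exercising the change (the checkpoint id, max_model_len, and sampling values are assumptions, not part of this commit; image inputs would go through vLLM's usual multimodal prompt format and are omitted for brevity):

# Sketch only: assumes a vLLM build with XPU support and that
# moonshotai/Kimi-VL-A3B-Thinking is the intended checkpoint id.
from vllm import LLM, SamplingParams

llm = LLM(
    model="moonshotai/Kimi-VL-A3B-Thinking",
    trust_remote_code=True,   # Kimi-VL ships custom model/processor code
    max_model_len=8192,
)
params = SamplingParams(temperature=0.6, max_tokens=256)
outputs = llm.generate(["Briefly explain what a vision transformer does."], params)
print(outputs[0].outputs[0].text)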
