Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions vllm_ascend/models/qwen3_next.py
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,8 @@
from vllm.transformers_utils.configs import Qwen3NextConfig
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

from vllm_ascend.utils import vllm_version_is

from vllm.model_executor.models.qwen3_next import ( # isort: skip
Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextForCausalLM,
Qwen3NextGatedDeltaNet, Qwen3NextModel, Qwen3NextSparseMoeBlock,
@@ -201,7 +203,11 @@ def _forward(
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_token_masks = attn_metadata.spec_token_masks
if vllm_version_is("0.11.0"):
spec_token_masks = attn_metadata.spec_token_masks
else:
spec_token_indx = attn_metadata.spec_token_indx
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -216,8 +222,9 @@ def _forward(

# 1. Set up dimensions for reshapes later
projected_states, _ = self.in_proj(hidden_states[:num_actual_tokens])
if spec_token_masks is not None:
spec_token_masks = spec_token_masks[:num_actual_tokens]
if vllm_version_is("0.11.0"):
if spec_token_masks is not None:
spec_token_masks = spec_token_masks[:num_actual_tokens]
projected_states_qkvz, projected_states_ba = torch.split(
projected_states,
[
@@ -242,8 +249,13 @@ def _forward(
mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None
else:
mixed_qkv_spec = mixed_qkv[spec_token_masks]
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
if vllm_version_is("0.11.0"):
mixed_qkv_spec = mixed_qkv[spec_token_masks]
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
else:
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
mixed_qkv_non_spec = mixed_qkv.index_select(
0, non_spec_token_indx)
else:
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
@@ -293,10 +305,16 @@ def _forward(
g_non_spec = None
beta_non_spec = None
else:
g_spec = g[:, spec_token_masks]
beta_spec = beta[:, spec_token_masks]
g_non_spec = g[:, ~spec_token_masks]
beta_non_spec = beta[:, ~spec_token_masks]
if vllm_version_is("0.11.0"):
g_spec = g[:, spec_token_masks]
beta_spec = beta[:, spec_token_masks]
g_non_spec = g[:, ~spec_token_masks]
beta_non_spec = beta[:, ~spec_token_masks]
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
@@ -404,8 +422,14 @@ def _forward(
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
core_attn_out[:, spec_token_masks] = core_attn_out_spec
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
if vllm_version_is("0.11.0"):
core_attn_out[:, spec_token_masks] = core_attn_out_spec
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
else:
core_attn_out.index_copy_(1, spec_token_indx,
core_attn_out_spec)
core_attn_out.index_copy_(1, non_spec_token_indx,
core_attn_out_non_spec)
elif spec_sequence_masks is not None:
core_attn_out = core_attn_out_spec
else:
Expand Down
Loading