Skip to content

Commit b99fd88

Browse files
vadiklyutiy and albertoperdomo2
authored and committed
[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (vllm-project#26437)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent fe94c23 commit b99fd88

File tree

3 files changed

+56
-36
lines changed

3 files changed

+56
-36
lines changed

vllm/model_executor/layers/fla/ops/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
4545
"""
4646

4747
cache_entries: tuple[tuple | None, dict | None, Any] = []
48-
cache_size = 4
48+
cache_size = 8
4949

5050
@functools.wraps(fn)
5151
def wrapper(*args: Any, **kwargs: Any) -> Any:

vllm/model_executor/models/qwen3_next.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def rearrange_mixed_qkv(self, mixed_qkv):
423423
(query, key),
424424
)
425425
value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
426-
return query, key, value
426+
return query.contiguous(), key.contiguous(), value.contiguous()
427427

428428
def forward(
429429
self,
@@ -455,16 +455,15 @@ def _forward(
455455
spec_query_start_loc = attn_metadata.spec_query_start_loc
456456
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
457457
spec_sequence_masks = attn_metadata.spec_sequence_masks
458-
spec_token_masks = attn_metadata.spec_token_masks
458+
spec_token_indx = attn_metadata.spec_token_indx
459+
non_spec_token_indx = attn_metadata.non_spec_token_indx
459460
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
460461
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
461462
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
462463
conv_state = self_kv_cache[0].transpose(-1, -2)
463464
ssm_state = self_kv_cache[1]
464465
num_actual_tokens = attn_metadata.num_actual_tokens
465466
num_accepted_tokens = attn_metadata.num_accepted_tokens
466-
if spec_token_masks is not None:
467-
spec_token_masks = spec_token_masks[:num_actual_tokens]
468467

469468
# 1. Set up dimensions for reshapes later
470469
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
@@ -487,8 +486,8 @@ def _forward(
487486
mixed_qkv_spec = mixed_qkv
488487
mixed_qkv_non_spec = None
489488
else:
490-
mixed_qkv_spec = mixed_qkv[spec_token_masks]
491-
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
489+
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
490+
mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
492491
else:
493492
mixed_qkv_spec = None
494493
mixed_qkv_non_spec = mixed_qkv
@@ -558,10 +557,10 @@ def _forward(
558557
g_non_spec = None
559558
beta_non_spec = None
560559
else:
561-
g_spec = g[:, spec_token_masks]
562-
beta_spec = beta[:, spec_token_masks]
563-
g_non_spec = g[:, ~spec_token_masks]
564-
beta_non_spec = beta[:, ~spec_token_masks]
560+
g_spec = g.index_select(1, spec_token_indx)
561+
beta_spec = beta.index_select(1, spec_token_indx)
562+
g_non_spec = g.index_select(1, non_spec_token_indx)
563+
beta_non_spec = beta.index_select(1, non_spec_token_indx)
565564
else:
566565
g_spec = None
567566
beta_spec = None
@@ -638,8 +637,9 @@ def _forward(
638637
dtype=core_attn_out_non_spec.dtype,
639638
device=core_attn_out_non_spec.device,
640639
)
641-
core_attn_out[:, spec_token_masks] = core_attn_out_spec
642-
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
640+
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
641+
core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
642+
643643
elif spec_sequence_masks is not None:
644644
core_attn_out = core_attn_out_spec
645645
else:

vllm/v1/attention/backends/gdn_attn.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class GDNAttentionMetadata:
4747
None # shape: [batch - num_spec_decodes,]
4848
)
4949
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
50-
spec_token_masks: torch.Tensor | None = (
51-
None # shape: [num_prefill_tokens + num_decode_tokens,]
52-
)
50+
spec_token_indx: torch.Tensor | None = None
51+
non_spec_token_indx: torch.Tensor | None = None
52+
5353
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
5454

5555
# The following attributes are for triton implementation of causal_conv1d
@@ -105,9 +105,14 @@ def __init__(
105105
dtype=torch.bool,
106106
device=device,
107107
)
108-
self.spec_token_masks = torch.empty(
108+
self.spec_token_indx = torch.empty(
109109
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
110-
dtype=torch.bool,
110+
dtype=torch.int32,
111+
device=device,
112+
)
113+
self.non_spec_token_indx = torch.empty(
114+
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
115+
dtype=torch.int32,
111116
device=device,
112117
)
113118
self.spec_query_start_loc = torch.empty(
@@ -166,7 +171,8 @@ def build( # type: ignore[override]
166171
split_decodes_and_prefills(m, decode_threshold=1)
167172
)
168173
num_spec_decode_tokens = 0
169-
spec_token_masks = None
174+
spec_token_indx = None
175+
non_spec_token_indx = None
170176
spec_state_indices_tensor = None
171177
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
172178
spec_query_start_loc = None
@@ -180,18 +186,23 @@ def build( # type: ignore[override]
180186
num_prefills = non_spec_query_lens.size(0) - num_decodes
181187
num_decode_tokens = num_decodes
182188
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
189+
num_spec_decode_tokens = (
190+
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
191+
)
183192

184193
if num_prefills == 0 and num_decodes == 0:
185-
spec_token_masks = torch.ones(
186-
(
187-
min(
188-
num_spec_decodes * (self.num_spec + 1),
189-
query_start_loc[-1].item(),
190-
)
191-
),
192-
dtype=torch.bool,
194+
spec_token_size = min(
195+
num_spec_decodes * (self.num_spec + 1),
196+
query_start_loc[-1].item(),
197+
)
198+
spec_token_indx = torch.arange(
199+
spec_token_size,
200+
dtype=torch.int32,
193201
device=query_start_loc.device,
194202
)
203+
non_spec_token_indx = torch.empty(
204+
0, dtype=torch.int32, device=query_start_loc.device
205+
)
195206
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
196207
non_spec_state_indices_tensor = None
197208
spec_query_start_loc = query_start_loc
@@ -200,6 +211,11 @@ def build( # type: ignore[override]
200211
spec_token_masks = torch.repeat_interleave(
201212
spec_sequence_masks, query_lens
202213
)
214+
index = torch.argsort(spec_token_masks)
215+
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
216+
non_spec_token_indx = index[:num_non_spec_tokens]
217+
spec_token_indx = index[num_non_spec_tokens:]
218+
203219
spec_state_indices_tensor = m.block_table_tensor[
204220
spec_sequence_masks, : self.num_spec + 1
205221
]
@@ -226,9 +242,6 @@ def build( # type: ignore[override]
226242
out=non_spec_query_start_loc[1:],
227243
)
228244

229-
num_spec_decode_tokens = (
230-
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
231-
)
232245
assert num_accepted_tokens is not None
233246
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
234247

@@ -274,12 +287,18 @@ def build( # type: ignore[override]
274287
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
275288
spec_sequence_masks[num_spec_decodes:].fill_(False)
276289

277-
assert spec_token_masks is not None
278-
self.spec_token_masks[: spec_token_masks.size(0)].copy_(
279-
spec_token_masks, non_blocking=True
290+
assert non_spec_token_indx is not None and spec_token_indx is not None
291+
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
292+
non_spec_token_indx, non_blocking=True
293+
)
294+
non_spec_token_indx = self.non_spec_token_indx[
295+
: non_spec_token_indx.size(0)
296+
]
297+
298+
self.spec_token_indx[: spec_token_indx.size(0)].copy_(
299+
spec_token_indx, non_blocking=True
280300
)
281-
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
282-
spec_token_masks[spec_token_masks.size(0) :].fill_(False)
301+
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
283302

284303
self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
285304
spec_query_start_loc, non_blocking=True
@@ -332,7 +351,8 @@ def build( # type: ignore[override]
332351
spec_state_indices_tensor=spec_state_indices_tensor,
333352
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
334353
spec_sequence_masks=spec_sequence_masks,
335-
spec_token_masks=spec_token_masks,
354+
spec_token_indx=spec_token_indx,
355+
non_spec_token_indx=non_spec_token_indx,
336356
num_accepted_tokens=num_accepted_tokens,
337357
nums_dict=nums_dict,
338358
batch_ptr=batch_ptr,

0 commit comments

Comments (0)