132 changes: 116 additions & 16 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
# not processing as this is not the actual sequence
return

if IS_VARLEN:
query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
tl.int64)
# revise state_len and seqlen
state_len = state_len - (seqlen -
(query_end_index - query_start_index))
seqlen = query_end_index - query_start_index
x_offset = query_start_index * stride_x_token
o_offset = query_start_index * stride_o_token
else:
query_start_index = idx_seq * seqlen
query_end_index = query_start_index + seqlen
x_offset = idx_seq * stride_x_seq
o_offset = idx_seq * stride_o_seq

if query_start_index == query_end_index:
return

if IS_SPEC_DECODING:
# The rolling of conv state:
#
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
1)
conv_state_token_offset = (
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
else:
conv_state_token_offset = 0

@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
if KERNEL_WIDTH >= 4:
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
if KERNEL_WIDTH >= 5:
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 6:
conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N]
col4 = tl.load(conv_states_ptrs, mask_w, 0.0)

# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

VAL = state_len - seqlen
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
) # [BLOCK_N]
x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N]

x_ptrs = x_base[None, :] + (
(idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 5:
w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor
w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 6:
w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor
w_col5 = tl.load(w_ptrs, mask_w, other=0.0)

x_base_1d = x_base # starting of chunk [BLOCK_N]
mask_x_1d = idx_feats < dim

# STEP 5: compute each token
for idx_token in tl.static_range(seqlen):
for idx_token in tl.range(seqlen):
acc = acc_preload

matrix_w = w_col0
Expand Down Expand Up @@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 5:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
matrix_x = col3
elif j == 4:
matrix_w = w_col4
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 6:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
matrix_x = col3
elif j == 4:
matrix_w = w_col4
matrix_x = col4
elif j == 5:
matrix_w = w_col5
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

acc += matrix_x * matrix_w # [BLOCK_N]

@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
col0 = col1
col1 = col2
col2 = matrix_x
elif KERNEL_WIDTH == 5:
col0 = col1
col1 = col2
col2 = col3
col3 = matrix_x
elif KERNEL_WIDTH == 6:
col0 = col1
col1 = col2
col2 = col3
col3 = col4
col4 = matrix_x

if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < seqlen) & (idx_feats < dim
) # token-index # feature-index
o_ptrs = o_ptr + (
idx_seq) * stride_o_seq + idx_token * stride_o_token + (
idx_feats * stride_o_dim)
o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
stride_o_dim)

tl.store(o_ptrs, acc, mask=mask_1d)

@@ -850,14 +920,18 @@ def causal_conv1d_update(
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
[shape=2: single token prediction]
[shape=3: single or multiple tokens prediction]
[shape=2 with num_tokens: continuous batching, where num_tokens is the
total number of tokens across all sequences in the batch]
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
@@ -870,13 +944,24 @@ def causal_conv1d_update(
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
num_accepted_tokens: (batch,), dtype int32
If not None, it indicates the number of accepted tokens for each
sequence in the batch.
This is used in speculative decoding, where the conv_state is updated
in a sliding window manner.
query_start_loc: (batch + 1,) int32
If not None, the input is given in a varlen fashion and this indicates
the starting index of each sequence in the batch.
max_query_len: int
If query_start_loc is not None, this indicates the maximum query
length in the batch.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
"""
if validate_data:
assert cache_seqlens is None # not implemented yet - ok for vLLM
@@ -886,11 +971,17 @@ def causal_conv1d_update(
activation = "silu" if activation is True else None
elif activation is not None:
assert activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
unsqueeze = query_start_loc is None and x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
if query_start_loc is None:
batch, dim, seqlen = x.shape
else:
assert conv_state_indices is not None
batch = conv_state_indices.size(0)
dim = x.size(1)
seqlen = max_query_len
_, width = weight.shape
# conv_state: (..., dim, state_len), where state_len >= width - 1
num_cache_lines, _, state_len = conv_state.size()
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
out = x
stride_w_dim, stride_w_width = weight.stride()

stride_x_seq, stride_x_dim, stride_x_token = x.stride(
) # X (batch, dim, seqlen)
if query_start_loc is None:
# X (batch, dim, seqlen)
stride_x_seq, stride_x_dim, stride_x_token = x.stride()
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
else:
# X (cu_seqlen, dim)
stride_x_token, stride_x_dim = x.stride()
stride_x_seq = 0
stride_o_token, stride_o_dim = out.stride()
stride_o_seq = 0

stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
stride_state_indices = conv_state_indices.stride(
@@ -945,6 +1043,7 @@ def grid(META):
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
query_start_loc,
out,
# Matrix dimensions
batch,
@@ -971,6 +1070,7 @@ def grid(META):
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
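Usage note (not part of the diff): a minimal sketch of how the new varlen mode of `causal_conv1d_update` could be invoked, following the docstring above. The tensor sizes, the CUDA device, and the dtype are assumptions for illustration; only `x`, `conv_state`, `weight`, `bias` are passed positionally, the rest as keywords.

```python
import torch
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update

# Illustrative sizes (assumed, not from the PR).
dim, width, state_len = 64, 4, 3           # state_len >= width - 1
seq_lens = [3, 1, 2]                       # per-sequence query lengths
num_tokens = sum(seq_lens)                 # 6 tokens across 3 sequences

# Varlen layout: x is (num_tokens, dim); boundaries come from query_start_loc.
x = torch.randn(num_tokens, dim, device="cuda", dtype=torch.float16)
conv_state = torch.zeros(8, dim, state_len, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)

query_start_loc = torch.tensor([0, 3, 4, 6], device="cuda", dtype=torch.int32)
conv_state_indices = torch.tensor([0, 1, 2], device="cuda", dtype=torch.int32)

out = causal_conv1d_update(
    x, conv_state, weight, bias,
    activation="silu",
    conv_state_indices=conv_state_indices,  # required in the varlen path
    query_start_loc=query_start_loc,
    max_query_len=max(seq_lens),
)
assert out.shape == x.shape                # (num_tokens, dim), same layout as x
```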
10 changes: 3 additions & 7 deletions vllm/model_executor/models/qwen3_next.py
@@ -417,9 +417,7 @@ def _forward(
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = (attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens +
attn_metadata.num_spec_decode_tokens)
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens

# 1. Set up dimensions for reshapes later
@@ -458,9 +456,6 @@ def _forward(

# 2.1: process the multi-query part
if spec_sequence_masks is not None:
mixed_qkv_spec = mixed_qkv_spec.view(
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
conv_state,
Expand All @@ -470,9 +465,10 @@ def _forward(
conv_state_indices=spec_state_indices_tensor[:, 0]
[:attn_metadata.num_spec_decodes],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
max_query_len=spec_state_indices_tensor.size(-1),
validate_data=False,
)
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')

# 2.2: process the remaining part
if attn_metadata.num_prefills > 0:
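Shape intuition (illustrative only, not code from the PR): the removed view/rearrange pair packed the spec tokens into a (batch, dim, seqlen) tensor, whereas the varlen path keeps the flat (num_tokens, dim) layout and describes sequence boundaries via query_start_loc. A toy sketch with assumed sizes:

```python
import torch

num_spec_decodes, num_spec, dim = 2, 2, 8       # assumed toy sizes
num_tokens = num_spec_decodes * (num_spec + 1)  # 6 spec tokens in total
mixed_qkv_spec = torch.randn(num_tokens, dim)

# Old path: reshape to (batch, seqlen, dim), then rearrange to (batch, dim, seqlen).
old_layout = mixed_qkv_spec.view(num_spec_decodes, -1, dim).permute(0, 2, 1)
assert old_layout.shape == (num_spec_decodes, dim, num_spec + 1)

# New path: keep (num_tokens, dim) and mark per-sequence boundaries instead.
spec_query_start_loc = torch.tensor([0, 3, 6], dtype=torch.int32)
assert spec_query_start_loc[-1].item() == mixed_qkv_spec.size(0)
```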
31 changes: 20 additions & 11 deletions vllm/v1/attention/backends/gdn_attn.py
@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int

has_initial_state: Optional[torch.Tensor] = None

@@ -74,8 +75,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_capture_size)
self.vllm_config.scheduler_config.max_num_seqs *
(self.num_spec + 1), self.compilation_config.max_capture_size)
Comment on lines 77 to +79
Contributor
critical

Changing self.decode_cudagraph_max_bs to be a token count (by multiplying with self.num_spec + 1) is incorrect, as this variable is used as a sequence count (batch size) for tensor allocations. For example, self.spec_state_indices_tensor is allocated with this as its first dimension (line 80), which is indexed by sequence, not by token. This change will lead to incorrect tensor allocations (either too large, wasting memory, or too small, causing out-of-bounds errors) and likely runtime failures.

To fix this correctly, decode_cudagraph_max_bs should remain a sequence count. A new variable should be introduced for the maximum token count if needed for the check at line 221.


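A sketch of the fix the reviewer describes, using hypothetical names (there is no such helper or field in the PR): keep the sequence-count bound for per-sequence tensor allocation and introduce a separate token-count bound for the cudagraph eligibility check.

```python
def cudagraph_capture_bounds(max_num_seqs: int, num_spec: int,
                             max_capture_size: int) -> tuple[int, int]:
    """Hypothetical helper: separate sequence-count and token-count bounds."""
    # Bound used for per-sequence allocations (e.g. spec_state_indices_tensor).
    max_bs = min(max_num_seqs, max_capture_size)
    # Bound used when comparing padded token counts for full cudagraph capture.
    max_tokens = min(max_num_seqs * (num_spec + 1), max_capture_size)
    return max_bs, max_tokens
```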
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),
@@ -194,9 +195,8 @@ def build( # type: ignore[override]
dim=0,
out=non_spec_query_start_loc[1:])

num_spec_decode_tokens = min(
num_spec_decodes * (self.num_spec + 1),
spec_token_masks.size(0))
num_spec_decode_tokens = (query_lens.sum().item() -
num_prefill_tokens - num_decode_tokens)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

@@ -206,14 +206,22 @@ def build( # type: ignore[override]
has_initial_state = has_initial_state[~spec_sequence_masks]
else:
has_initial_state = None
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
num_spec_decode_tokens

# prepare tensors for cudagraph
#
# With speculative decoding, the xgrammar backend may roll back tokens,
# causing some sequences to have fewer draft tokens than self.num_spec.
#
# In that case, the maximum possible batch size for n tokens is
# min(n, cudagraph_max_bs).
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
and num_spec_decodes <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
num_total_tokens = self.vllm_config.pad_for_cudagraph(
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(
m.num_actual_tokens)
batch_size = num_total_tokens // (self.num_spec + 1)
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True)
@@ -229,7 +237,7 @@ def build( # type: ignore[override]
assert spec_token_masks is not None
self.spec_token_masks[:spec_token_masks.size(0)].copy_(
spec_token_masks, non_blocking=True)
spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
spec_token_masks[spec_token_masks.size(0):].fill_(False)

self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
@@ -248,9 +256,9 @@ def build( # type: ignore[override]
if (self.use_full_cuda_graph and num_prefills == 0
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs):
num_total_tokens = self.vllm_config.pad_for_cudagraph(
num_actual_tokens = self.vllm_config.pad_for_cudagraph(
m.num_actual_tokens)
batch_size = num_total_tokens
batch_size = num_actual_tokens

self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True)
@@ -274,6 +282,7 @@ def build( # type: ignore[override]
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=num_actual_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
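A worked token-count example (numbers invented for illustration) of how the new num_spec_decode_tokens and num_actual_tokens are derived when xgrammar has rolled back a draft token, so spec-decode sequences no longer contribute a uniform num_spec + 1 tokens:

```python
# 2 prefill sequences, 3 plain decodes, 2 spec-decode sequences with
# num_spec = 2; xgrammar rolled one draft token back on the last sequence.
query_lens = [5, 7,        # prefills
              1, 1, 1,     # decodes
              3, 2]        # spec decodes (num_spec + 1 = 3, one rolled back)
num_prefill_tokens = 5 + 7                                    # 12
num_decode_tokens = 3                                         # 1 token each
num_spec_decode_tokens = sum(query_lens) - num_prefill_tokens - num_decode_tokens
assert num_spec_decode_tokens == 5                            # 3 + 2, not 2 * 3
num_actual_tokens = (num_prefill_tokens + num_decode_tokens +
                     num_spec_decode_tokens)                   # 20
```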