[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels #17146

Merged · 15 commits · May 7, 2025

Changes from all commits:
7 changes: 4 additions & 3 deletions tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -6,7 +6,7 @@
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.mamba2_metadata import (
_seq_idx_to_chunk_indices_offsets)
_query_start_loc_to_chunk_indices_offsets)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.platforms import current_platform
@@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
last_taken, exhausted, n_heads,
d_head, itype):

chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])

Y, new_states = mamba_chunk_scan_combined(
X,
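For context on the rename tested above: the old `_seq_idx_to_chunk_indices_offsets` recovered sequence boundaries from a per-token `seq_idx` tensor, while the new `_query_start_loc_to_chunk_indices_offsets` consumes the cumulative `query_start_loc` that attention metadata already provides. A minimal sketch of the two equivalent representations (tensor values are illustrative):

```python
import torch

# Three prefill sequences of lengths 3, 2 and 4 in one varlen batch.
# query_start_loc is the cumulative form with a prepended 0.
query_start_loc = torch.tensor([0, 3, 5, 9], dtype=torch.int32)

# Per-token sequence index, derived the same way the new
# prepare_mamba2_metadata does it: [0, 0, 0, 1, 1, 2, 2, 2, 2]
seq_idx = torch.repeat_interleave(
    torch.arange(len(query_start_loc) - 1, dtype=torch.int32),
    query_start_loc.diff(),
    output_size=int(query_start_loc[-1]))
```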
88 changes: 45 additions & 43 deletions vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -13,7 +13,6 @@

@dataclass
class Mamba2Metadata:
has_prefill: bool

has_initial_states: torch.Tensor
prep_initial_states: bool
@@ -24,21 +23,23 @@ class Mamba2Metadata:
chunk_offsets: torch.Tensor


def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):

# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
cu_seqlens = query_start_loc[1:] # remove prepended 0

# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)

cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):

@@ -60,48 +61,49 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):

def prepare_mamba2_metadata(
chunk_size: int,
input_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Mamba2Metadata:

# compute number of prefill and decode requests
# NOTE: in V0 we assume prefills are before decodes
num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens

seq_idx = None
chunk_indices, chunk_offsets = None, None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states = None
prep_initial_states = False
if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# precompute flag to avoid device syncs later in mamba2 forwards
prep_initial_states = torch.any(has_initial_states).item()

has_prefill = attn_metadata.num_prefills > 0

seq_idx = None
chunk_indices, chunk_offsets = None, None
if has_prefill:
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
for i, (srt, end) in enumerate(
zip(
attn_metadata.query_start_loc,
attn_metadata.query_start_loc[1:],
)):
seq_idx[srt:end] = i
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = \
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
# precompute flag to avoid device syncs in mamba2 layer forwards
# prep is only needed for mamba2 ssd prefill processing
prep_initial_states = torch.any(has_initial_states).item()

query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
seq_idx = torch.repeat_interleave(torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device),
query_start_loc.diff(),
output_size=num_prefill_tokens)
seq_idx.unsqueeze_(0)

# compute metadata for chunked prefill.
# actually this is only needed if there are initial states,
# but this is determinable only from attention metadata yet
# unavailable from the top-level model forward. Rather than
# complicating things to extract said metadata, we simply just
# compute them once at the top level model forward and reuse
# them in mamba layers. If not needed, they will be ignored
# inside mamba kernels.
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)

return Mamba2Metadata(has_prefill=has_prefill,
has_initial_states=has_initial_states,
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
if prep_initial_states:
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens)

return Mamba2Metadata(has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
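A quick worked example of the renamed helper (assuming the module import path shown in the diff). With `chunk_size=4` and two prefill sequences of lengths 3 and 4, the first sequence boundary (token 3) does not land on the chunk grid, so one extra logical chunk is inserted: `N = ceil(7/4) + 1 = 3` entries in each output, letting the second sequence restart its scan mid-chunk.

```python
import torch
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    _query_start_loc_to_chunk_indices_offsets)

# Two prefill sequences of lengths 3 and 4, i.e. 7 tokens total.
query_start_loc = torch.tensor([0, 3, 7], dtype=torch.int32)
chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
    query_start_loc, chunk_size=4, total_seqlens=7)

# Both outputs have N = ceil(7/4) + 1 = 3 entries; chunk_offsets records
# where inside a physical chunk each logical chunk begins.
print(chunk_indices, chunk_offsets)
```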
178 changes: 104 additions & 74 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -388,10 +388,15 @@ def forward_cuda(
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# are the same and reused for all mamba layers in the same iteration
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

seq_len, _ = hidden_states.shape
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0

groups_time_state_size = self.n_groups * self.ssm_state_size

# 1. Gated MLP's linear projection
Expand All @@ -406,44 +411,32 @@ def forward_cuda(
dim=-1,
)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

if mamba2_metadata.has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|

# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
hidden_states_B_C = causal_conv1d_fn(
hidden_states_B_C.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]

# TODO: Why is this needed?
hidden_states_B_C = hidden_states_B_C.contiguous()
else:
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)

# - get hidden_states, B and C after depthwise convolution.
hidden_states, B, C = torch.split(
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
Expand All @@ -453,32 +446,56 @@ def forward_cuda(
dim=-1,
)

# 3. State Space Model sequence transformation
if mamba2_metadata.has_prefill:
ssd_output_list = []

# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]

# TODO: Why is this needed?
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)

# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor], 0)
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)

scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
hidden_states_p.view(1, num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@@ -487,52 +504,65 @@
)

# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor] = varlen_state
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state

# - reshape
hidden_states = scan_output.view(seq_len, -1)
else:
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))

# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)

hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)

# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand(
A_d = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states_reshaped = hidden_states.view(
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)

# - the hidden is reshaped into number of current batches
# - in this case there is no more prefill, so the batches gen
# 1 token at a time
# - thus hidden will be (bs, num_heads, head_dim)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using "mamba_cache_params.state_indices_tensor", just as
# above in the prefill case
# using state_indices_tensor_d

hidden_states = selective_state_update(
hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states_reshaped,
dt,
A,
B,
C,
D,
hidden_states_d,
dt_d,
A_d,
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
state_batch_indices=state_indices_tensor_d,
)
hidden_states = hidden_states.view(
-1, (self.num_heads // self.tp_size) * self.head_dim)
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))

# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)

# # 4. gated MLP
# 4. gated MLP
hidden_states = self.norm(hidden_states, gate)

# # 5. Final linear projection
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
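The control flow of the refactored `forward_cuda`, reduced to a self-contained sketch (the two arithmetic stand-ins replace the conv/SSM kernels and are purely illustrative, not vLLM APIs). V0 schedules prefills before decodes, so a single `torch.split` along the token dimension separates the two request types, and `torch.vstack` restores the original order afterwards:

```python
import torch

def forward_split(hidden: torch.Tensor, num_prefill_tokens: int,
                  num_decodes: int) -> torch.Tensor:
    # Split the flattened varlen batch: prefill tokens first, then one
    # token per decode request.
    hidden_p, hidden_d = torch.split(
        hidden, [num_prefill_tokens, num_decodes], dim=0)

    outputs = []
    if num_prefill_tokens > 0:
        outputs.append(hidden_p * 2.0)  # stand-in for the chunked-scan path
    if num_decodes > 0:
        outputs.append(hidden_d + 1.0)  # stand-in for the single-step path

    # Merge prefill and decode outputs before the gated MLP.
    return torch.vstack(outputs)

# Example: 7 prefill tokens and 2 decode requests, hidden size 16.
out = forward_split(torch.randn(9, 16), num_prefill_tokens=7, num_decodes=2)
assert out.shape == (9, 16)
```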
1 change: 0 additions & 1 deletion vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x,
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
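The deleted assertion appears to be redundant rather than wrong: `batch`, `seqlen`, `nheads` and `headdim` are presumably unpacked from `x.shape` just above this hunk, which would make the check a tautology. For reference, a sketch of the tensor layout the prefill path now feeds into the combined scan, with the checks that remain (dummy sizes are illustrative):

```python
import torch

batch, seqlen = 1, 7          # flattened varlen prefill tokens
nheads, headdim = 4, 8
ngroups, dstate = 2, 16

x = torch.randn(batch, seqlen, nheads, headdim)   # hidden_states_p view
dt = torch.randn(batch, seqlen, nheads)           # dt_p.unsqueeze(0)
A = torch.randn(nheads)
B = torch.randn(batch, seqlen, ngroups, dstate)   # B_p view
C = torch.randn(batch, seqlen, ngroups, dstate)   # C_p view

# The checks that remain in _mamba_chunk_scan_combined_fwd:
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads,)
assert C.shape == B.shape
```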