
Commit 73cfb3c

[Model] Clean up and simplify Mamba2 Metadata Usage in both V0 and V1 (#24331)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent 4e5affe commit 73cfb3c

3 files changed, +45 -77 lines changed

vllm/model_executor/layers/mamba/mamba2_metadata.py

Lines changed: 24 additions & 40 deletions
@@ -17,14 +17,13 @@
 
 @dataclass
 class Mamba2Metadata:
-
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
-
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    has_initial_states_p: torch.Tensor
+    seq_idx_p: torch.Tensor
+    chunk_indices_p: torch.Tensor
+    chunk_offsets_p: torch.Tensor
     """
     With continuous batching layout of `x` in vLLM, to enable a Triton program
     to handle a request in parallel, two supporting tensors are used
@@ -68,19 +67,18 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
 def prepare_mamba2_metadata(
     chunk_size: int,
     attn_metadata: AttentionMetadata,
-    mamba2_metadata=None,
 ) -> Mamba2Metadata:
 
     # compute number of prefill and decode requests
     # NOTE: in V0 we assume prefills are before decodes
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
 
-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
+    seq_idx_p = None
+    chunk_indices_p, chunk_offsets_p = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
-    has_initial_states = None
+    has_initial_states_p = None
     prep_initial_states = False
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
@@ -91,44 +89,30 @@ def prepare_mamba2_metadata(
         # precompute flag to avoid device syncs later in mamba2 layer
         # forwards
         # prep is only needed for mamba2 ssd prefill processing
-        has_initial_states = attn_metadata.context_lens_tensor > 0
-        prep_initial_states = torch.any(
-            has_initial_states[:num_prefills]).item()
-        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc.device),
-                                          query_start_loc.diff(),
-                                          output_size=num_prefill_tokens)
-        seq_idx.unsqueeze_(0)
+        has_initial_states_p = (
+            attn_metadata.context_lens_tensor[:num_prefills] > 0)
+        prep_initial_states = torch.any(has_initial_states_p).item()
+        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx_p = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
+                                            query_start_loc_p.diff(),
+                                            output_size=num_prefill_tokens)
+        seq_idx_p.unsqueeze_(0)
 
         # We compute metadata for chunked prefill once at the top level model
         # forward and reuse them in mamba layers. If not needed, they will be
         # ignored inside mamba kernels.
         if prep_initial_states:
-            chunk_indices, chunk_offsets = \
+            chunk_indices_p, chunk_offsets_p = \
                 _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc, chunk_size, num_prefill_tokens)
-
-    if mamba2_metadata is not None:
-        mamba2_metadata.has_initial_states = has_initial_states
-        mamba2_metadata.prep_initial_states = prep_initial_states
-        mamba2_metadata.chunk_size = chunk_size
-        mamba2_metadata.seq_idx = seq_idx
-        mamba2_metadata.chunk_indices = chunk_indices
-        mamba2_metadata.chunk_offsets = chunk_offsets
-        # We use 1 reset flag:
-        #  * mamba2_metadata.cu_seqlen is None
-        #    update config specific to (each input)
-        #    (become available at first layer, e.g. conv_weights)
-        mamba2_metadata.cu_seqlen = None  # suppose to be updated at each input
-
-        return mamba2_metadata
-    return Mamba2Metadata(has_initial_states=has_initial_states,
+                    query_start_loc_p, chunk_size, num_prefill_tokens)
+
+    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
-                          seq_idx=seq_idx,
-                          chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                          seq_idx_p=seq_idx_p,
+                          chunk_indices_p=chunk_indices_p,
+                          chunk_offsets_p=chunk_offsets_p)
 
 
 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
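
For reference, a minimal standalone sketch (not part of the commit) of how prepare_mamba2_metadata derives seq_idx_p from the prefill slice of query_start_loc; the request sizes below are made up for illustration:

import torch

# Hypothetical batch: 3 prefill requests with 4, 2 and 3 tokens each, packed
# back-to-back in the continuous-batching layout.
num_prefills = 3
num_prefill_tokens = 9
query_start_loc_p = torch.tensor([0, 4, 6, 9], dtype=torch.int32)

# Label every prefill token with the index of the request it belongs to, so a
# Triton kernel can tell requests apart inside the packed tensor.
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)

print(seq_idx_p)  # tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2]], dtype=torch.int32)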

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 10 additions & 19 deletions
@@ -518,22 +518,19 @@ def forward_cuda(
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
                 state_indices_tensor = attn_metadata.state_indices_tensor
-                has_initial_states_p = attn_metadata.has_initial_states_p
-                prep_initial_states = attn_metadata.prep_initial_states
-                chunk_size = attn_metadata.chunk_size
-                seq_idx_p = attn_metadata.seq_idx_p
-                chunk_indices_p = attn_metadata.chunk_indices_p
-                chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
+
+        # Common members between V1 metadata and V0 metadata
+        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -677,15 +674,9 @@ def forward_cuda(
         # 3. State Space Model sequence transformation
         initial_states = None
         if (has_initial_states_p is not None and prep_initial_states):
-            # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)
 
         # NOTE: final output is an in-place update of out tensor
         varlen_state = mamba_chunk_scan_combined(
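
Below is a minimal sketch (not from the commit) of the now-unified initial-state gather: because has_initial_states_p already covers only the prefill requests in both V0 and V1, the V0-specific [:num_prefills] slice could be dropped. All shapes and values are made up:

import torch

num_prefills, num_heads, head_dim, d_state = 2, 4, 8, 16
ssm_state = torch.randn(10, num_heads, head_dim, d_state)  # cached SSM states
state_indices_tensor_p = torch.tensor([3, 7])  # cache slots of the two prefills
has_initial_states_p = torch.tensor([True, False])

# Broadcast the per-request flag over (heads, head_dim, d_state): a request
# without an initial state reads zeros instead of stale cache contents.
initial_states = torch.where(
    has_initial_states_p[:, None, None, None],
    ssm_state[state_indices_tensor_p], 0)

assert initial_states.shape == (num_prefills, num_heads, head_dim, d_state)
assert torch.all(initial_states[1] == 0)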

vllm/model_executor/models/plamo2.py

Lines changed: 11 additions & 18 deletions
@@ -279,22 +279,19 @@ def forward_cuda(
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
                 state_indices_tensor = attn_metadata.state_indices_tensor
-                has_initial_states_p = attn_metadata.has_initial_states_p
-                prep_initial_states = attn_metadata.prep_initial_states
-                chunk_size = attn_metadata.chunk_size
-                seq_idx_p = attn_metadata.seq_idx_p
-                chunk_indices_p = attn_metadata.chunk_indices_p
-                chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
+
+        # Common members between V1 metadata and V0 metadata
+        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -414,14 +411,10 @@ def forward_cuda(
         initial_states = None
         if has_initial_states_p is not None and prep_initial_states:
             # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)
+
         varlen_state = mamba_chunk_scan_combined(
             hidden_states_p.view(1, num_prefill_tokens,
                                  self.num_heads // self.tp_size,
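
Both mixer files follow the same pattern: after the rename, the V0 Mamba2Metadata dataclass and the V1 attention metadata expose identical *_p field names, so one block can read them for either engine. A simplified stand-in (not the real vLLM types) illustrating that shared read path:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMamba2Metadata:
    # Same field names as the V1 attention metadata; this is what makes the
    # shared "if mamba2_metadata is not None:" block in forward_cuda possible.
    prep_initial_states: bool
    chunk_size: int
    has_initial_states_p: Optional[torch.Tensor]
    seq_idx_p: Optional[torch.Tensor]
    chunk_indices_p: Optional[torch.Tensor]
    chunk_offsets_p: Optional[torch.Tensor]


def read_common_members(mamba2_metadata: Optional[ToyMamba2Metadata]):
    # Mirrors the block added to mamba_mixer2.py and plamo2.py: a single read
    # path regardless of whether the metadata came from the V0 or V1 code path.
    if mamba2_metadata is None:
        return None
    return (mamba2_metadata.has_initial_states_p,
            mamba2_metadata.prep_initial_states,
            mamba2_metadata.chunk_size,
            mamba2_metadata.seq_idx_p,
            mamba2_metadata.chunk_indices_p,
            mamba2_metadata.chunk_offsets_p)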

0 commit comments
