 
 @dataclass
 class Mamba2Metadata:
-
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
-
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    has_initial_states_p: torch.Tensor
+    seq_idx_p: torch.Tensor
+    chunk_indices_p: torch.Tensor
+    chunk_offsets_p: torch.Tensor
     """
     With continuous batching layout of `x` in vLLM, to enable a Triton program
     to handle a request in parallel, two supporting tensors are used
@@ -68,19 +67,18 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
 def prepare_mamba2_metadata(
     chunk_size: int,
     attn_metadata: AttentionMetadata,
-    mamba2_metadata=None,
 ) -> Mamba2Metadata:
 
     # compute number of prefill and decode requests
     # NOTE: in V0 we assume prefills are before decodes
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
 
-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
+    seq_idx_p = None
+    chunk_indices_p, chunk_offsets_p = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
-    has_initial_states = None
+    has_initial_states_p = None
     prep_initial_states = False
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
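
A sketch of what these prefill-only flags compute, with made-up batch values (context_lens_tensor and the prefills-before-decodes layout are from the diff; the numbers are hypothetical):

    import torch

    # V0 layout: prefills come before decodes, so the first
    # num_prefills entries of context_lens_tensor are the prefills
    context_lens_tensor = torch.tensor([4, 0, 7, 1, 1])  # 3 prefills, 2 decodes
    num_prefills = 3

    has_initial_states_p = context_lens_tensor[:num_prefills] > 0
    # tensor([ True, False,  True]): requests 0 and 2 resume from cached state
    prep_initial_states = torch.any(has_initial_states_p).item()  # True
    # the .item() forces one host-device sync here, rather than one
    # per mamba2 layer later in the forward pass
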
@@ -91,44 +89,30 @@ def prepare_mamba2_metadata(
             # precompute flag to avoid device syncs later in mamba2 layer
             # forwards
             # prep is only needed for mamba2 ssd prefill processing
-            has_initial_states = attn_metadata.context_lens_tensor > 0
-            prep_initial_states = torch.any(
-                has_initial_states[:num_prefills]).item()
-        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc.device),
-            query_start_loc.diff(),
-            output_size=num_prefill_tokens)
-        seq_idx.unsqueeze_(0)
+            has_initial_states_p = (
+                attn_metadata.context_lens_tensor[:num_prefills] > 0)
+            prep_initial_states = torch.any(has_initial_states_p).item()
+        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx_p = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
+            query_start_loc_p.diff(),
+            output_size=num_prefill_tokens)
+        seq_idx_p.unsqueeze_(0)
 
         # We compute metadata for chunked prefill once at the top level model
         # forward and reuse them in mamba layers. If not needed, they will be
         # ignored inside mamba kernels.
         if prep_initial_states:
-            chunk_indices, chunk_offsets = \
+            chunk_indices_p, chunk_offsets_p = \
                 _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc, chunk_size, num_prefill_tokens)
-
-    if mamba2_metadata is not None:
-        mamba2_metadata.has_initial_states = has_initial_states
-        mamba2_metadata.prep_initial_states = prep_initial_states
-        mamba2_metadata.chunk_size = chunk_size
-        mamba2_metadata.seq_idx = seq_idx
-        mamba2_metadata.chunk_indices = chunk_indices
-        mamba2_metadata.chunk_offsets = chunk_offsets
-        # We use 1 reset flag:
-        #  * mamba2_metadata.cu_seqlen is None
-        #    update config specific to (each input)
-        #    (become available at first layer, e.g. conv_weights)
-        mamba2_metadata.cu_seqlen = None  # suppose to be updated at each input
-
-        return mamba2_metadata
-    return Mamba2Metadata(has_initial_states=has_initial_states,
+                    query_start_loc_p, chunk_size, num_prefill_tokens)
+
+    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
-                          seq_idx=seq_idx,
-                          chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                          seq_idx_p=seq_idx_p,
+                          chunk_indices_p=chunk_indices_p,
+                          chunk_offsets_p=chunk_offsets_p)
 
 
 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
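
For intuition on why chunk_indices_p and chunk_offsets_p exist: with chunked prefill, a request that resumes from cached state can begin in the middle of a logical chunk of size chunk_size, so the SSD kernels need to know where such chunks must be split. A rough illustration of the boundary arithmetic only (chunk size and lengths are hypothetical; the actual tensor layout produced by _query_start_loc_to_chunk_indices_offsets is not reproduced here):

    import torch

    chunk_size = 4
    # two packed prefill requests of lengths 6 and 5
    query_start_loc = torch.tensor([0, 6, 11])

    # request 1 begins at token 6, i.e. 2 tokens into logical chunk 1,
    # so that chunk spans two requests and must be treated as a split point
    for start in query_start_loc[1:-1].tolist():
        print(f"boundary at token {start}: chunk {start // chunk_size}, "
              f"offset {start % chunk_size}")
    # boundary at token 6: chunk 1, offset 2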