3030from vllm .model_executor .layers .logits_processor import LogitsProcessor
3131from vllm .model_executor .layers .mamba .abstract import MambaBase
3232from vllm .model_executor .layers .mamba .mamba2_metadata import (
33- Mamba2Metadata , prepare_mamba2_metadata , update_metadata )
33+ Mamba2Metadata , prepare_mamba2_metadata )
3434from vllm .model_executor .layers .mamba .mamba_utils import (
3535 MambaStateDtypeCalculator , MambaStateShapeCalculator )
3636from vllm .model_executor .layers .mamba .ops .causal_conv1d import (
@@ -285,6 +285,7 @@ def forward_cuda(
285285 seq_idx_p = attn_metadata .seq_idx_p
286286 chunk_indices_p = attn_metadata .chunk_indices_p
287287 chunk_offsets_p = attn_metadata .chunk_offsets_p
288+ query_start_loc_p = attn_metadata .query_start_loc_p
288289 else :
289290 conv_state = mamba_cache_params .conv_state
290291 ssm_state = mamba_cache_params .ssm_state
@@ -295,6 +296,7 @@ def forward_cuda(
295296 seq_idx_p = mamba2_metadata .seq_idx
296297 chunk_indices_p = mamba2_metadata .chunk_indices
297298 chunk_offsets_p = mamba2_metadata .chunk_offsets
299+ query_start_loc_p = mamba2_metadata .query_start_loc_p
298300
299301 # 1. Gated MLP's linear projection
300302 projected_states = self .in_proj (hidden_states )
@@ -336,9 +338,6 @@ def forward_cuda(
336338 [num_decodes , num_prefills ],
337339 dim = 0 ,
338340 )
339- query_start_loc_p = (
340- attn_metadata .query_start_loc [- num_prefills - 1 :] -
341- num_decodes if has_prefill else None )
342341 else :
343342 hidden_states_p , hidden_states_d = torch .split (
344343 hidden_states ,
@@ -354,9 +353,6 @@ def forward_cuda(
354353 [num_prefills , num_decodes ],
355354 dim = 0 ,
356355 )
357- query_start_loc_p = (attn_metadata .query_start_loc [:num_prefills +
358- 1 ]
359- if has_prefill else None )
360356
361357 # Preallocate output tensor to avoid memcpy cost for merging prefill
362358 # and decode outputs
@@ -388,9 +384,6 @@ def forward_cuda(
388384 # pointed to by "state_indices_tensor"
389385 x = hidden_states_p .transpose (
390386 0 , 1 ) # this is the form that causal-conv see
391- if mamba2_metadata .cu_seqlen is None :
392- mamba2_metadata = update_metadata (x , query_start_loc_p ,
393- mamba2_metadata )
394387 hidden_states_p = causal_conv1d_fn (
395388 x ,
396389 conv_weights ,
0 commit comments