Skip to content

Commit 6db94a4

Browse files
committed
fix query_start_loc_p affected by metadata refactor
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
1 parent d6aa941 commit 6db94a4

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

vllm/model_executor/models/plamo2.py

Lines changed: 3 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
from vllm.model_executor.layers.logits_processor import LogitsProcessor
3131
from vllm.model_executor.layers.mamba.abstract import MambaBase
3232
from vllm.model_executor.layers.mamba.mamba2_metadata import (
33-
Mamba2Metadata, prepare_mamba2_metadata, update_metadata)
33+
Mamba2Metadata, prepare_mamba2_metadata)
3434
from vllm.model_executor.layers.mamba.mamba_utils import (
3535
MambaStateDtypeCalculator, MambaStateShapeCalculator)
3636
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -285,6 +285,7 @@ def forward_cuda(
285285
seq_idx_p = attn_metadata.seq_idx_p
286286
chunk_indices_p = attn_metadata.chunk_indices_p
287287
chunk_offsets_p = attn_metadata.chunk_offsets_p
288+
query_start_loc_p = attn_metadata.query_start_loc_p
288289
else:
289290
conv_state = mamba_cache_params.conv_state
290291
ssm_state = mamba_cache_params.ssm_state
@@ -295,6 +296,7 @@ def forward_cuda(
295296
seq_idx_p = mamba2_metadata.seq_idx
296297
chunk_indices_p = mamba2_metadata.chunk_indices
297298
chunk_offsets_p = mamba2_metadata.chunk_offsets
299+
query_start_loc_p = mamba2_metadata.query_start_loc_p
298300

299301
# 1. Gated MLP's linear projection
300302
projected_states = self.in_proj(hidden_states)
@@ -336,9 +338,6 @@ def forward_cuda(
336338
[num_decodes, num_prefills],
337339
dim=0,
338340
)
339-
query_start_loc_p = (
340-
attn_metadata.query_start_loc[-num_prefills - 1:] -
341-
num_decodes if has_prefill else None)
342341
else:
343342
hidden_states_p, hidden_states_d = torch.split(
344343
hidden_states,
@@ -354,9 +353,6 @@ def forward_cuda(
354353
[num_prefills, num_decodes],
355354
dim=0,
356355
)
357-
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
358-
1]
359-
if has_prefill else None)
360356

361357
# Preallocate output tensor to avoid memcpy cost for merging prefill
362358
# and decode outputs
@@ -388,9 +384,6 @@ def forward_cuda(
388384
# pointed to by "state_indices_tensor"
389385
x = hidden_states_p.transpose(
390386
0, 1) # this is the form that causal-conv see
391-
if mamba2_metadata.cu_seqlen is None:
392-
mamba2_metadata = update_metadata(x, query_start_loc_p,
393-
mamba2_metadata)
394387
hidden_states_p = causal_conv1d_fn(
395388
x,
396389
conv_weights,

0 commit comments

Comments
 (0)