64 changes: 24 additions & 40 deletions vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -17,14 +17,13 @@

 @dataclass
 class Mamba2Metadata:
-
-    has_initial_states: torch.Tensor
-    prep_initial_states: bool
-
-    chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
-
+    has_initial_states_p: torch.Tensor
+    seq_idx_p: torch.Tensor
+    chunk_indices_p: torch.Tensor
+    chunk_offsets_p: torch.Tensor
     """
     With continuous batching layout of `x` in vLLM, to enable a Triton program
     to handle a request in parallel, two supporting tensors are used
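
A note on the rename for orientation: the `_p` suffix marks tensors that describe only the prefill requests in the batch. Below is a minimal sketch of a populated instance; the field names come from this diff, the import path is the file under change, and every value is invented for illustration.

    # Sketch only: a prefill batch of two requests with 3 and 2 new tokens,
    # neither resuming from cached state. Shapes inferred, not authoritative.
    import torch

    from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata

    meta = Mamba2Metadata(
        has_initial_states_p=torch.tensor([False, False]),  # flag per prefill request
        prep_initial_states=False,  # True only if some prefill resumes cached state
        chunk_size=256,
        seq_idx_p=torch.tensor([[0, 0, 0, 1, 1]],
                               dtype=torch.int32),  # owning request per token
        chunk_indices_p=None,  # materialized only when prep_initial_states is True
        chunk_offsets_p=None,
    )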
@@ -68,19 +67,18 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
 def prepare_mamba2_metadata(
     chunk_size: int,
     attn_metadata: AttentionMetadata,
-    mamba2_metadata=None,
 ) -> Mamba2Metadata:

     # compute number of prefill and decode requests
     # NOTE: in V0 we assume prefills are before decodes
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens

-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
+    seq_idx_p = None
+    chunk_indices_p, chunk_offsets_p = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
-    has_initial_states = None
+    has_initial_states_p = None
     prep_initial_states = False
@@ -91,44 +89,30 @@ def prepare_mamba2_metadata(
         # precompute flag to avoid device syncs later in mamba2 layer
         # forwards
         # prep is only needed for mamba2 ssd prefill processing
-        has_initial_states = attn_metadata.context_lens_tensor > 0
-        prep_initial_states = torch.any(
-            has_initial_states[:num_prefills]).item()
-        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc.device),
-                                          query_start_loc.diff(),
-                                          output_size=num_prefill_tokens)
-        seq_idx.unsqueeze_(0)
+        has_initial_states_p = (
+            attn_metadata.context_lens_tensor[:num_prefills] > 0)
+        prep_initial_states = torch.any(has_initial_states_p).item()
+        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx_p = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
+                                            query_start_loc_p.diff(),
+                                            output_size=num_prefill_tokens)
+        seq_idx_p.unsqueeze_(0)

         # We compute metadata for chunked prefill once at the top level model
         # forward and reuse them in mamba layers. If not needed, they will be
         # ignored inside mamba kernels.
         if prep_initial_states:
-            chunk_indices, chunk_offsets = \
+            chunk_indices_p, chunk_offsets_p = \
                 _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc, chunk_size, num_prefill_tokens)
-
-    if mamba2_metadata is not None:
-        mamba2_metadata.has_initial_states = has_initial_states
-        mamba2_metadata.prep_initial_states = prep_initial_states
-        mamba2_metadata.chunk_size = chunk_size
-        mamba2_metadata.seq_idx = seq_idx
-        mamba2_metadata.chunk_indices = chunk_indices
-        mamba2_metadata.chunk_offsets = chunk_offsets
-        # We use 1 reset flag:
-        #  * mamba2_metadata.cu_seqlen is None
-        #    update config specific to (each input)
-        #    (become available at first layer, e.g. conv_weights)
-        mamba2_metadata.cu_seqlen = None  # suppose to be updated at each input
-
-        return mamba2_metadata
-    return Mamba2Metadata(has_initial_states=has_initial_states,
-                          prep_initial_states=prep_initial_states,
-                          chunk_size=chunk_size,
-                          seq_idx=seq_idx,
-                          chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                    query_start_loc_p, chunk_size, num_prefill_tokens)
+
+    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
+                          prep_initial_states=prep_initial_states,
+                          chunk_size=chunk_size,
+                          seq_idx_p=seq_idx_p,
+                          chunk_indices_p=chunk_indices_p,
+                          chunk_offsets_p=chunk_offsets_p)


 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
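
The seq_idx_p construction above deserves a close read: query_start_loc_p.diff() gives each prefill request's token count, and torch.repeat_interleave stretches the request ids across their tokens; passing output_size lets the result shape be known up front. A self-contained sketch with invented lengths, using names that mirror the diff:

    import torch

    # Toy stand-ins: 3 prefill requests contributing 4, 2 and 3 tokens each.
    num_prefills = 3
    num_prefill_tokens = 9
    query_start_loc_p = torch.tensor([0, 4, 6, 9], dtype=torch.int32)

    seq_idx_p = torch.repeat_interleave(
        torch.arange(num_prefills, dtype=torch.int32),
        query_start_loc_p.diff(),        # per-request token counts: [4, 2, 3]
        output_size=num_prefill_tokens)  # known size avoids a device sync
    seq_idx_p.unsqueeze_(0)              # add the leading dim the kernels expect

    print(seq_idx_p)  # tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2]], dtype=torch.int32)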
29 changes: 10 additions & 19 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -478,22 +478,19 @@ def forward_cuda(
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            prep_initial_states = attn_metadata.prep_initial_states
+            chunk_size = attn_metadata.chunk_size
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
-            has_initial_states_p = mamba2_metadata.has_initial_states
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p

         groups_time_state_size = self.n_groups * self.ssm_state_size
@@ -639,15 +636,9 @@ def forward_cuda(
         # 3. State Space Model sequence transformation
         initial_states = None
         if (has_initial_states_p is not None and prep_initial_states):
-            # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)

         # NOTE: final output is an in-place update of out tensor
         varlen_state = mamba_chunk_scan_combined(
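
With the VLLM_USE_V1 branch folded away, the single torch.where is the entire initial-state gather: state_indices_tensor_p selects each prefill request's cache slot, and the per-request boolean broadcasts across the state dimensions so requests without history start from zeros. A toy sketch with invented shapes:

    import torch

    # Invented shapes: cache with 8 slots, 2 heads, head dim 4, state size 4.
    ssm_state = torch.randn(8, 2, 4, 4)
    state_indices_tensor_p = torch.tensor([5, 2])   # cache slot per prefill request
    has_initial_states_p = torch.tensor([True, False])

    # Broadcast the flag over (head, head_dim, dstate); zero where no prior state.
    initial_states = torch.where(
        has_initial_states_p[:, None, None, None],
        ssm_state[state_indices_tensor_p], 0)

    assert initial_states.shape == (2, 2, 4, 4)
    assert torch.all(initial_states[1] == 0)  # second request starts from zeros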
29 changes: 11 additions & 18 deletions vllm/model_executor/models/plamo2.py
@@ -279,22 +279,19 @@ def forward_cuda(
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
+            has_initial_states_p = attn_metadata.has_initial_states_p
+            prep_initial_states = attn_metadata.prep_initial_states
+            chunk_size = attn_metadata.chunk_size
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
-            has_initial_states_p = mamba2_metadata.has_initial_states
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -414,14 +411,10 @@ def forward_cuda(
         initial_states = None
         if has_initial_states_p is not None and prep_initial_states:
-            # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)

         varlen_state = mamba_chunk_scan_combined(
             hidden_states_p.view(1, num_prefill_tokens,
                                  self.num_heads // self.tp_size,
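
The deleted [:num_prefills] indexing depended on the V0 layout called out in the NOTE earlier ("in V0 we assume prefills are before decodes"): with prefills packed at the front, slicing the first num_prefills entries of any per-request tensor yields exactly the prefill rows. A toy illustration of that convention, numbers invented:

    import torch

    # V0-era ordering: 2 prefill requests first, then 3 decode requests.
    num_prefills = 2
    context_lens_tensor = torch.tensor([7, 0, 12, 3, 9])  # one entry per request

    # Flags for the prefill requests only; decode rows are never touched here.
    has_initial_states_p = context_lens_tensor[:num_prefills] > 0
    print(has_initial_states_p)  # tensor([ True, False])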