Commit cdaf1e6: Add padding-free to bamba
1 parent: 34f76bb

3 files changed (+218, -72 lines)

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 76 additions & 34 deletions
@@ -514,28 +514,17 @@ def cuda_kernels_forward(
         self,
         hidden_states: torch.Tensor,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False,
     ):
         # 1. Gated MLP's linear projection
-        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         projected_states = self.in_proj(hidden_states)

         # Set up dimensions for reshapes later
         batch_size, seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size

-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )
-
         # getting projected states from cache if it exists
         if use_precomputed_states:
             gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
@@ -598,7 +587,7 @@ def cuda_kernels_forward(
                 A,
                 D=self.D,
                 chunk_size=self.chunk_size,
-                seq_idx=None,  # was seq_idx
+                seq_idx=seq_idx,
                 activation=self.activation,
                 rmsnorm_weight=self.norm.weight,
                 rmsnorm_eps=self.norm.variance_epsilon,
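
For context (not part of the diff): `seq_idx` gives the fused `mamba_split_conv1d_scan_combined` kernel a per-token sequence label so the convolution and SSM scan do not carry state across packed sequences. A minimal sketch of its expected form, assuming two sequences of lengths 3 and 2 packed into one row:

import torch

# Hypothetical packed example: tokens 0-2 belong to sequence 0, tokens 3-4 to sequence 1.
seq_idx = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)
# This matches what get_seq_idx_from_cu_seq_lens (added further down in this diff) returns
# for cu_seq_lens = torch.tensor([[0, 3, 5]]).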
@@ -682,29 +671,18 @@ def torch_forward(
         self,
         input_states,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False
     ):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype

         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
         projected_states = self.in_proj(input_states)
         gate, hidden_states_B_C, dt = projected_states.split(
             [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
         )

-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )

         # 2. Convolution sequence transformation
         if use_precomputed_states:
@@ -891,15 +869,27 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
     ):
+        batch_size, seq_len, _ = hidden_states.shape
+        use_precomputed_states = (
+            cache_params is not None
+            and cache_params.has_previous_state
+            and seq_len == 1
+            and cache_params.conv_states[self.layer_idx].shape[0]
+            == cache_params.ssm_states[self.layer_idx].shape[0]
+            == batch_size
+            and cache_position is not None
+            and cache_position[0] > 0
+        )
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
-        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(
+                hidden_states, cache_params, attention_mask, seq_idx, use_precomputed_states
+            )
+        if seq_idx is not None:
+            raise ValueError("Non-trivial seq_idx only supported on cuda path.")
+        return self.torch_forward(hidden_states, cache_params, attention_mask, use_precomputed_states)


 class BambaMLP(nn.Module):
@@ -938,10 +928,42 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
+    batch_size = position_ids.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    device = position_ids.device
+    idxs = torch.arange(1, position_ids.shape[1], device=device)
+    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
+    cu_seq_lens = torch.cat(
+        (
+            torch.tensor([0], device=device),
+            idxs[non_increasing_pos_id],
+            torch.tensor(position_ids[0].shape, device=device),
+        ),
+    )
+    return cu_seq_lens[None]
+
+
+def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    batch_size = cu_seq_lens.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    seq_idx = torch.cat(
+        [
+            torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
+            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+        ]
+    )
+    return seq_idx[None]
+
+
 class BambaDecoderLayer(nn.Module):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()

+        # The `num_experts` code below is redundant, but it prevents modular_model_converter.py from
+        # generating an unwanted BambaSparseMoeBlock in modeling_bamba.py
         num_experts = 1
         ffn_layer_class = BambaMLP if num_experts == 1 else None
         self.feed_forward = ffn_layer_class(config)
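
A worked example (not part of the diff) of what the two new helpers produce for a padding-free batch, where `position_ids` restart at 0 at every sequence boundary:

import torch

# Two sequences of lengths 3 and 2 packed into a single row.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
# tensor([[0, 3, 5]])  -- cumulative sequence lengths, FlashAttention-style

seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
# tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)  -- per-token sequence index for the mamba kernels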
@@ -966,7 +988,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -996,11 +1018,29 @@ def forward(

         # this is a hybrid decoder layer
         if self.layer_type == "mamba":
+            # Padding-free processing for efficient training. position_ids and FlashAttentionKwargs
+            # are ignored by mamba layers if not training.
+            if not self.training:
+                seq_idx = None
+            elif "cu_seq_lens_k" in kwargs:
+                seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
+            elif position_ids is not None:
+                cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
+                if len(cu_seq_lens[0]) == 2:
+                    # If cu_seq_lens only has two elements, then it is semantically equivalent to
+                    # `seq_idx=None`, which is more efficient.
+                    seq_idx = None
+                else:
+                    seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
+            else:
+                seq_idx = None
             hidden_states = self.mamba(
                 hidden_states=hidden_states,
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                seq_idx=seq_idx,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
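
To trace the branch above (illustration only, assuming a training step with no precomputed `cu_seq_lens_k` in `kwargs`): a single unbroken sequence collapses to `seq_idx=None`, while a genuinely packed batch materialises a per-token index that is handed to `self.mamba(...)`:

import torch

# Single unbroken sequence: cu_seq_lens has only two entries, so seq_idx stays None (cheaper).
single = torch.tensor([[0, 1, 2, 3, 4]])
assert len(get_cu_seq_lens_from_position_ids(single)[0]) == 2

# Packed batch of two sequences: a non-trivial seq_idx is built.
packed = torch.tensor([[0, 1, 2, 0, 1]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(packed)   # tensor([[0, 3, 5]])
seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)       # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)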
@@ -1200,6 +1240,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1273,6 +1314,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )

         hidden_states = layer_outputs[0]
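
End-to-end, a padding-free training batch could therefore look like the sketch below. This is a hedged example, not from the diff: field names follow the FlashAttention / `DataCollatorWithFlattening` convention in transformers, and the packed mamba path additionally requires the fused CUDA kernels (otherwise the non-trivial `seq_idx` raises the ValueError shown in the mixer forward hunk above).

import torch

# Hypothetical flattened batch: two sequences packed into one row of shape (1, total_tokens).
batch = {
    "input_ids": torch.tensor([[101, 7, 9, 101, 4]]),   # toy token ids
    "position_ids": torch.tensor([[0, 1, 2, 0, 1]]),    # restart at 0 at each sequence boundary
}
# In training mode, model(**batch) lets each mamba BambaDecoderLayer recover seq_idx from
# position_ids (see the decoder-layer hunk above), while any FlashAttentionKwargs passed
# alongside (cu_seq_lens_q/k, max_length_q/k) are forwarded to the attention layers via
# **flash_attn_kwargs.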
