@@ -514,28 +514,17 @@ def cuda_kernels_forward(
         self,
         hidden_states: torch.Tensor,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False,
     ):
         # 1. Gated MLP's linear projection
-        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         projected_states = self.in_proj(hidden_states)

         # Set up dimensions for reshapes later
         batch_size, seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size

-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )
-
         # getting projected states from cache if it exists
         if use_precomputed_states:
             gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
@@ -598,7 +587,7 @@ def cuda_kernels_forward(
                 A,
                 D=self.D,
                 chunk_size=self.chunk_size,
-                seq_idx=None,  # was seq_idx
+                seq_idx=seq_idx,
                 activation=self.activation,
                 rmsnorm_weight=self.norm.weight,
                 rmsnorm_eps=self.norm.variance_epsilon,
@@ -682,29 +671,18 @@ def torch_forward(
         self,
         input_states,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False,
     ):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype

         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
         projected_states = self.in_proj(input_states)
         gate, hidden_states_B_C, dt = projected_states.split(
             [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
         )

-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )

         # 2. Convolution sequence transformation
         if use_precomputed_states:
@@ -891,15 +869,27 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
     ):
+        batch_size, seq_len, _ = hidden_states.shape
+        use_precomputed_states = (
+            cache_params is not None
+            and cache_params.has_previous_state
+            and seq_len == 1
+            and cache_params.conv_states[self.layer_idx].shape[0]
+            == cache_params.ssm_states[self.layer_idx].shape[0]
+            == batch_size
+            and cache_position is not None
+            and cache_position[0] > 0
+        )
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
-        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(
+                hidden_states, cache_params, attention_mask, seq_idx, use_precomputed_states
+            )
+        if seq_idx is not None:
+            raise ValueError("Non-trivial seq_idx only supported on cuda path.")
+        return self.torch_forward(hidden_states, cache_params, attention_mask, use_precomputed_states)


 class BambaMLP(nn.Module):
@@ -938,10 +928,42 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
+    batch_size = position_ids.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    device = position_ids.device
+    idxs = torch.arange(1, position_ids.shape[1], device=device)
+    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
+    cu_seq_lens = torch.cat(
+        (
+            torch.tensor([0], device=device),
+            idxs[non_increasing_pos_id],
+            torch.tensor(position_ids[0].shape, device=device),
+        ),
+    )
+    return cu_seq_lens[None]
+
+
+def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    batch_size = cu_seq_lens.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    seq_idx = torch.cat(
+        [
+            torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
+            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+        ]
+    )
+    return seq_idx[None]
+
+
 class BambaDecoderLayer(nn.Module):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()

+        # The `num_experts` code below is redundant, but it prevents modular_model_converter.py from
+        # generating an unwanted BambaSparseMoeBlock in modeling_bamba.py
         num_experts = 1
         ffn_layer_class = BambaMLP if num_experts == 1 else None
         self.feed_forward = ffn_layer_class(config)
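The two module-level helpers added above can be sanity-checked by hand; the short sketch below uses illustrative tensors only (not part of the patch) to show the expected round trip from packed position_ids to cu_seq_lens to seq_idx.

import torch

# Two packed sequences of lengths 3 and 2 in one row (batch size 1 is the only supported case).
position_ids = torch.tensor([[0, 1, 2, 0, 1]])

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
# Boundaries sit wherever position_ids stops increasing: tensor([[0, 3, 5]])

seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
# One int32 label per token naming its packed sequence: tensor([[0, 0, 0, 1, 1]])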
@@ -966,7 +988,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -996,11 +1018,29 @@ def forward(

         # this is a hybrid decoder layer
         if self.layer_type == "mamba":
+            # Padding-free processing for efficient training. position_ids and FlashAttentionKwargs
+            # are ignored by mamba layers if not training.
+            if not self.training:
+                seq_idx = None
+            elif "cu_seq_lens_k" in kwargs:
+                seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
+            elif position_ids is not None:
+                cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
+                if len(cu_seq_lens[0]) == 2:
+                    # If cu_seq_lens only has two elements, then it is semantically equivalent to
+                    # `seq_idx=None`, which is more efficient.
+                    seq_idx = None
+                else:
+                    seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
+            else:
+                seq_idx = None
             hidden_states = self.mamba(
                 hidden_states=hidden_states,
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                seq_idx=seq_idx,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
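In the branch above, a collator that already supplies FlashAttentionKwargs short-circuits the derivation via cu_seq_lens_k; otherwise the boundaries are recovered from position_ids, and a single unpacked sequence collapses to the trivial case. A minimal sketch with illustrative values (not from the patch):

position_ids = torch.tensor([[0, 1, 2, 3]])  # one sequence, no packing
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)  # tensor([[0, 4]])
# Only two boundary entries, so the layer keeps seq_idx=None and skips the extra work.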
@@ -1200,6 +1240,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1273,6 +1314,7 @@ def forward(
                     use_cache=use_cache,
                     cache_position=cache_position,
                     position_embeddings=position_embeddings,
+                    **flash_attn_kwargs,
                 )

             hidden_states = layer_outputs[0]
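Putting the pieces together, a padding-free training step under this change could look roughly like the sketch below. The checkpoint path and token ids are placeholders, and it assumes a CUDA device with the mamba kernels and flash-attention installed, since the torch fallback rejects a non-trivial seq_idx.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint path; bf16 and flash_attention_2 are illustrative choices.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/bamba-checkpoint",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")
model.train()

# Two documents of lengths 3 and 2 packed into a single row with no padding.
input_ids = torch.tensor([[101, 102, 103, 201, 202]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1]], device="cuda")

# The mamba layers derive seq_idx from position_ids (or reuse cu_seq_lens_k when the
# collator passes FlashAttentionKwargs), so no padded attention_mask is required.
outputs = model(input_ids=input_ids, position_ids=position_ids, labels=input_ids)
outputs.loss.backward()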