@@ -19,7 +19,8 @@
 # limitations under the License.
 """PyTorch Bamba model."""
 
-from typing import Optional, Tuple, Union
+from functools import partial
+from typing import Optional, Tuple, TypedDict, Union
 
 import torch
 import torch.utils.checkpoint
@@ -46,7 +47,12 @@
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, can_return_tuple, logging
+from ...processing_utils import Unpack
+from ...utils import (
+    auto_docstring,
+    can_return_tuple,
+    logging,
+)
 from ...utils.import_utils import is_causal_conv1d_available, is_flash_attn_2_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig
@@ -71,6 +77,31 @@
 logger = logging.get_logger(__name__)
 
 
+class BambaFlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for advanced Flash Attention, causal-conv1d, and mamba_ssm kernel usage.
+    Use cases include padding-free training and fewer `torch.compile` graph breaks.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`):
+            Maximum sequence length for the query state.
+        max_length_k (`int`):
+            Maximum sequence length for the key state.
+        seq_idx (`torch.IntTensor`):
+            Index of each packed sequence.
+    """
+
+    cu_seq_lens_q: torch.LongTensor
+    cu_seq_lens_k: torch.LongTensor
+    max_length_q: int
+    max_length_k: int
+    seq_idx: torch.IntTensor
+
+
 # Adapted from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache for the v2 mixer
 class HybridMambaAttentionDynamicCache(modeling_jamba.HybridMambaAttentionDynamicCache):
     """
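Note: the new `BambaFlashAttentionKwargs` follow the FlashAttention varlen convention. A minimal sketch of what these kwargs look like for two sequences of lengths 3 and 2 packed into a single row (values are illustrative, not from the PR):

```python
import torch

# Two packed sequences of lengths 3 and 2: cumulative lengths [0, 3, 5].
flash_attn_kwargs = {
    "cu_seq_lens_q": torch.tensor([0, 3, 5], dtype=torch.long),
    "cu_seq_lens_k": torch.tensor([0, 3, 5], dtype=torch.long),
    "max_length_q": 3,  # longest query segment
    "max_length_k": 3,  # longest key segment
    "seq_idx": torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int32),
}
```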
@@ -278,6 +309,7 @@ def cuda_kernels_forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
     ):
         # 1. Gated MLP's linear projection
         hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
@@ -360,7 +392,7 @@ def cuda_kernels_forward(
                 A,
                 D=self.D,
                 chunk_size=self.chunk_size,
-                seq_idx=None,  # was seq_idx
+                seq_idx=seq_idx,
                 activation=self.activation,
                 rmsnorm_weight=self.norm.weight,
                 rmsnorm_eps=self.norm.variance_epsilon,
@@ -401,6 +433,7 @@ def cuda_kernels_forward(
                     weight=self.conv1d.weight.squeeze(1),
                     bias=self.conv1d.bias,
                     activation=self.activation,
+                    seq_idx=seq_idx,
                 ).transpose(1, 2)
 
             hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
@@ -420,7 +453,7 @@ def cuda_kernels_forward(
                 chunk_size=self.chunk_size,
                 D=self.D,
                 z=None,
-                seq_idx=None,
+                seq_idx=seq_idx,
                 return_final_states=True,
                 dt_bias=self.dt_bias,
                 dt_softplus=True,
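With `seq_idx` now threaded through to `causal_conv1d_fn`, `mamba_chunk_scan_combined`, and `mamba_split_conv1d_scan_combined`, the fused kernels can keep packed sequences from convolving or scanning across sequence boundaries. A hedged sketch of deriving `seq_idx` from cumulative lengths (hypothetical helper, not part of the PR):

```python
import torch

def seq_idx_from_cu_seq_lens(cu_seq_lens: torch.LongTensor) -> torch.IntTensor:
    """Expand cumulative lengths, e.g. [0, 3, 5], into per-token
    sequence indices [[0, 0, 0, 1, 1]] as the fused kernels expect."""
    lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
    idx = torch.repeat_interleave(torch.arange(len(lengths), dtype=torch.int32), lengths)
    return idx.unsqueeze(0)  # kernels take a (batch, total_tokens) layout
```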
@@ -654,9 +687,15 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.IntTensor] = None,
+        **kwargs,
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask, seq_idx)
+        if seq_idx is not None:
+            raise NotImplementedError(
+                "`seq_idx` support requires the fast path. Please install `mamba_ssm` and `causal_conv1d`."
+            )
         dtype = hidden_states.dtype
         if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
             # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
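The slow (pure PyTorch) path deliberately rejects `seq_idx`, since only the fused kernels understand packed-sequence boundaries. A caller-side sketch of the same availability check this file already imports:

```python
# Sketch: verify the fast-path dependencies before passing `seq_idx`.
from transformers.utils.import_utils import (
    is_causal_conv1d_available,
    is_mamba_2_ssm_available,
)

if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()):
    raise RuntimeError("Install `mamba_ssm` and `causal_conv1d` to use `seq_idx`.")
```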
@@ -701,7 +740,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -721,8 +760,8 @@ def forward(
                 Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                 with `head_dim` being the embedding dimension of each attention head.
             kwargs (`dict`, *optional*):
-                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
-                into the model
+                Arbitrary kwargs. Can be used to provide `BambaFlashAttentionKwargs` for
+                padding-free training and/or to improve `torch.compile` performance.
         """
 
         residual = hidden_states
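For reference, a hedged end-to-end sketch of a padding-free forward call; the tensor names are illustrative, and in practice this metadata usually comes from a flattening data collator rather than being built by hand:

```python
# `packed_input_ids` holds several sequences concatenated into one row;
# `packed_position_ids` restarts at 0 at every sequence boundary.
outputs = model(
    input_ids=packed_input_ids,
    position_ids=packed_position_ids,
    cu_seq_lens_q=cu_seq_lens,
    cu_seq_lens_k=cu_seq_lens,
    max_length_q=max_length,
    max_length_k=max_length,
    seq_idx=seq_idx,
)
```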
@@ -736,6 +775,7 @@ def forward(
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
@@ -838,7 +878,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,  # NOOP kwargs, for now
+        **kwargs: Unpack[BambaFlashAttentionKwargs],
     ) -> BaseModelOutputWithPast:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -890,7 +930,7 @@ def forward(
 
         if self.gradient_checkpointing and self.training:
             layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
+                partial(decoder_layer.__call__, **kwargs),
                 hidden_states,
                 layer_mask,
                 position_ids,
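The `partial` wrapper is needed because the reentrant gradient-checkpointing path only forwards positional tensor arguments, so keyword arguments must be bound before the call. A minimal standalone illustration of the pattern (not the Transformers wrapper itself):

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint

def layer(x, scale=1.0):
    return x * scale

x = torch.randn(2, 4, requires_grad=True)
# Bind the keyword argument up front, then checkpoint the wrapped callable.
out = checkpoint(partial(layer, scale=0.5), x, use_reentrant=True)
```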
@@ -910,6 +950,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
+                **kwargs,
             )
 
         hidden_states = layer_outputs[0]