diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py
index 460c1f3b32acbf..89f08dd3cd3276 100644
--- a/src/transformers/models/mamba/configuration_mamba.py
+++ b/src/transformers/models/mamba/configuration_mamba.py
@@ -79,6 +79,8 @@ class MambaConfig(PretrainedConfig):
             Whether or not to rescale `out_proj` weights when initializing.
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the cache should be used.
+        use_mambapy (`bool`, *optional*, defaults to `False`):
+            Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
 
     Example:
 
@@ -123,6 +125,7 @@ def __init__(
         time_step_floor=1e-4,
         rescale_prenorm_residual=False,
         use_cache=True,
+        use_mambapy=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -149,5 +152,6 @@ def __init__(
         self.rescale_prenorm_residual = rescale_prenorm_residual
         self.residual_in_fp32 = residual_in_fp32
         self.use_cache = use_cache
+        self.use_mambapy = use_mambapy
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 5edb28ad7416e3..50c0f9ebe4a580 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -33,12 +33,17 @@
     add_start_docstrings_to_model_forward,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
+from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
 from .configuration_mamba import MambaConfig
 
 
 logger = logging.get_logger(__name__)
 
+if is_mambapy_available():
+    from mambapy.pscan import pscan
+else:
+    pscan = None
+
 if is_mamba_ssm_available():
     from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
@@ -87,6 +92,8 @@ def __init__(self, config: MambaConfig, layer_idx: int):
         self.activation = config.hidden_act
         self.act = ACT2FN[config.hidden_act]
 
+        self.use_mambapy = config.use_mambapy
+
         # projection of the input hidden states
         self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
         # selective projection used to make dt, B and C input dependant
@@ -105,11 +112,23 @@
         self.use_bias = config.use_bias
 
         if not is_fast_path_available:
-            logger.warning_once(
-                "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d"
-            )
+            if self.use_mambapy:
+                if is_mambapy_available():
+                    logger.warning_once(
+                        "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                        " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
+                        " https://github.com/Dao-AILab/causal-conv1d"
+                    )
+                else:
+                    raise ImportError(
+                        "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
+                    )
+            else:
+                logger.warning_once(
+                    "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
+                    " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
+                    " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                )
 
     def cuda_kernels_forward(
         self,
@@ -257,17 +276,24 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
         deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
 
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        scan_outputs = []
-        for i in range(seq_len):
-            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
-            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
-            scan_outputs.append(scan_output[:, :, 0])
-        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
-        scan_output = scan_output + (hidden_states * self.D[None, :, None])
-        scan_output = (scan_output * self.act(gate))
+        if self.use_mambapy and self.training and cache_params is None:
+            hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2))  # [batch, seq_len, intermediate_size, ssm_state_size]
 
-        if cache_params is not None:
-            cache_params.update_ssm_state(self.layer_idx, ssm_state)
+            scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)  # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + hidden_states * self.D[None, :, None]
+            scan_output = scan_output * self.act(gate)
+        else:
+            scan_outputs = []
+            for i in range(seq_len):
+                ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
+                scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
+                scan_outputs.append(scan_output[:, :, 0])
+            scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
+            scan_output = scan_output + (hidden_states * self.D[None, :, None])
+            scan_output = (scan_output * self.act(gate))
+
+            if cache_params is not None:
+                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
 
         # 4. Final linear projection
         contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index a5ea4eb1850c57..c52da62c1de8e5 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -395,6 +395,12 @@ def is_causal_conv1d_available():
     return False
 
 
+def is_mambapy_available():
+    if is_torch_available():
+        return _is_package_available("mambapy")
+    return False
+
+
 def is_torch_mps_available():
     if is_torch_available():
         import torch
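Usage note (not part of the diff): a minimal sketch of how a user would opt into the new fallback, assuming this patch is applied and the `mambapy` package is installed while the CUDA kernels (`mamba-ssm`, `causal-conv1d`) are not. The small config values are illustrative only.

```python
# Minimal sketch: enabling the mamba.py fallback for training.
from transformers import MambaConfig, MambaForCausalLM

# `use_mambapy` is the flag added by this diff; the tiny sizes just keep the example light.
config = MambaConfig(hidden_size=64, num_hidden_layers=2, use_mambapy=True)
model = MambaForCausalLM(config)
model.train()  # the pscan path is only taken in training mode and when no cache is passed
```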
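For reviewers, a sketch of the recurrence the `pscan` call stands in for. The sequential reference below mirrors the loop kept in the `else` branch, reshaped to the `[batch, seq_len, intermediate_size, ssm_state_size]` layout that `pscan` receives after the `transpose(1, 2)` calls; the assumption (not verified here) is that `mambapy.pscan.pscan` returns the same stack of states, computed as a parallel scan rather than a Python loop over `seq_len`.

```python
import torch


def sequential_scan(discrete_A: torch.Tensor, deltaB_u: torch.Tensor) -> torch.Tensor:
    """Reference recurrence h_t = A_t * h_{t-1} + Bu_t, returning every intermediate state.

    Both inputs are [batch, seq_len, intermediate_size, ssm_state_size], matching what the
    pscan branch passes in after transposing.
    """
    batch, seq_len, d_inner, d_state = discrete_A.shape
    h = discrete_A.new_zeros(batch, d_inner, d_state)
    states = []
    for t in range(seq_len):
        h = discrete_A[:, t] * h + deltaB_u[:, t]  # elementwise decay plus input contribution
        states.append(h)
    return torch.stack(states, dim=1)  # [batch, seq_len, intermediate_size, ssm_state_size]


# Assumption: pscan(discrete_A, deltaB_u) is expected to match this output while exposing the
# work as a parallel scan, which is what makes the training-time fallback faster than the loop.
```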