diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 50c0f9ebe4a580..fb519bee3da03f 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -327,6 +327,9 @@ def forward(self, hidden_states):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
 
+    def extra_repr(self):
+        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
+
 
 class MambaBlock(nn.Module):
     def __init__(self, config, layer_idx):
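
For context (not part of the diff): `nn.Module.__repr__` inserts whatever string `extra_repr` returns inside the parentheses when a module is printed, so this change makes the normalized size and epsilon visible in `print(model)` output. Below is a minimal, self-contained sketch using a stand-in class with the same attributes as the patched `MambaRMSNorm`; it is illustrative only, not the full implementation from `modeling_mamba.py`.

```python
import torch
from torch import nn


class MambaRMSNorm(nn.Module):
    # Stand-in with only the attributes extra_repr reads (weight, variance_epsilon).
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def extra_repr(self):
        # nn.Module.__repr__ places this string inside the parentheses of the module's repr.
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"


print(MambaRMSNorm(768, eps=1e-5))
# Prints: MambaRMSNorm(768, eps=1e-05)
# Without extra_repr the same line would print as a bare "MambaRMSNorm()".
```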