From 8b31706fc12c075abca3aca77ceb41aaf220d570 Mon Sep 17 00:00:00 2001
From: Rohit Dwivedula <25080952+rohitdwivedula@users.noreply.github.com>
Date: Wed, 24 Jul 2024 02:09:59 -0500
Subject: [PATCH] adds: extra_repr() to MambaRMSNorm to include hidden size /
 size of weights in the layer (#32171)

* adds: extra_repr() to MambaRMSNorm to include the hidden size of the layer

* style fix with ruff:
---
 src/transformers/models/mamba/modeling_mamba.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 50c0f9ebe4a580..fb519bee3da03f 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -327,6 +327,9 @@ def forward(self, hidden_states):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
 
+    def extra_repr(self):
+        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"
+
 
 class MambaBlock(nn.Module):
     def __init__(self, config, layer_idx):
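
Note: PyTorch's nn.Module.__repr__ folds the string returned by extra_repr() into the printed module summary, which is what makes this three-line change useful. A minimal sketch of the effect, using a standalone copy of the class (the hidden size 768 passed below is an illustrative assumption, not a value from the patch):

    import torch
    from torch import nn

    class MambaRMSNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.variance_epsilon = eps

        def extra_repr(self):
            # Added by this patch: surfaces the weight size and epsilon
            # whenever the module (or a model containing it) is printed.
            return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"

    print(MambaRMSNorm(768))
    # Before the patch, the line would read:  MambaRMSNorm()
    # After the patch, it reads:              MambaRMSNorm(768, eps=1e-06)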