Skip to content

Commit

Permalink
adds: extra_repr() to MambaRMSNorm to include hidden size / size of w…
Browse files Browse the repository at this point in the history
…eights in the layer (#32171)

* adds: extra_repr() to MambaRMSNorm to include the hidden size of the layer

* style fix with ruff:
  • Loading branch information
rohitdwivedula authored and itazap committed Jul 25, 2024
1 parent ba0a2f1 commit 1c61cbc
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def forward(self, hidden_states):
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)

def extra_repr(self):
return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"


class MambaBlock(nn.Module):
def __init__(self, config, layer_idx):
Expand Down

0 comments on commit 1c61cbc

Please sign in to comment.