adds: extra_repr() to MambaRMSNorm to include hidden size / size of weights in the layer (huggingface#32171)

* adds: extra_repr() to MambaRMSNorm to include the hidden size of the layer

* style fix with ruff:
rohitdwivedula authored and zucchini-nlp committed Jul 24, 2024
1 parent 9a35400 commit 8b31706
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -327,6 +327,9 @@ def forward(self, hidden_states):
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"


class MambaBlock(nn.Module):
    def __init__(self, config, layer_idx):
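
To illustrate what the added `extra_repr()` changes, here is a minimal sketch of how PyTorch folds that string into a module's `repr()`. `RMSNormSketch` is a hypothetical stand-in written for this example, not the actual `MambaRMSNorm` class; it keeps only the `weight` and `variance_epsilon` attributes that the method reads, and the sizes used are arbitrary.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Hypothetical stand-in for MambaRMSNorm, reduced to what repr() needs."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # One learnable scale per hidden dimension, so weight.shape[0] == hidden_size.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def extra_repr(self):
        # Same body as the method added in this commit.
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"


layer = RMSNormSketch(hidden_size=768, eps=1e-5)
layer_repr = repr(layer)
print(layer_repr)  # RMSNormSketch(768, eps=1e-05)
```

Without `extra_repr()`, a module with no submodules prints as just its class name followed by `()`, so the hidden size and epsilon would not appear when printing a model.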