Fix gemma2 accuracy through the correct softcapping logic #2842

Merged 2 commits on Dec 2, 2024

Conversation

AllentDan (Collaborator)

NOTE
There is still one other place that is not aligned with Gemma2, but I don't think it is important.
The Gemma2 implementation of RMS Norm:

import torch


class DefaultRMSNormImpl(RMSNormImpl):
    """RMS norm implementation api."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        self.hidden_size = hidden_size
        self.eps = eps

    def forward(self,
                x: torch.Tensor,
                weight: torch.Tensor,
                residual: torch.Tensor = None):
        """forward."""
        input_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x
        # Normalize in float32 for numerical stability.
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        # Multiply by the weight in float32 first, then cast back to the input dtype.
        x = (weight.float() * x).to(input_dtype)
        if residual is None:
            return x
        return x, residual

Llama computes x.to(float16) * w, whereas Gemma computes (x * w).to(float16).
See huggingface/transformers#29402.
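
A minimal standalone sketch (not code from this PR) of how the two orderings can diverge once activations are cast down to half precision, using a toy input and weight:

import torch

# Toy activations and RMS-norm weight, both kept in float32 to start.
x = torch.randn(4, 64, dtype=torch.float32)
w = torch.randn(64, dtype=torch.float32) * 0.1 + 1.0

variance = x.pow(2).mean(-1, keepdim=True)
normed = x * torch.rsqrt(variance + 1e-6)

# Llama-style: cast the normalized activations first, then multiply by the weight.
llama_out = normed.to(torch.float16) * w.to(torch.float16)

# Gemma-style: multiply in float32 first, then cast the product.
gemma_out = (w * normed).to(torch.float16)

# The two differ by rounding error, which can matter for accuracy-sensitive models.
print((llama_out.float() - gemma_out.float()).abs().max())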

@lvhan028 requested a review from grimoire on Dec 2, 2024 05:33
@lvhan028 added the Bug:P1 label on Dec 2, 2024
@lvhan028 merged commit b91ce9a into InternLM:main on Dec 2, 2024
5 checks passed
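
For background on the fix named in the PR title: Gemma2 applies tanh-based soft-capping to its attention and final logits. The sketch below shows only the usual formulation of that operation; it is not the actual change in this PR, and the cap values are taken from the public Gemma2 config for illustration.

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Gemma2-style soft capping: squashes logits smoothly into (-cap, cap).
    return cap * torch.tanh(logits / cap)

# Gemma2's config uses 50.0 for attention logits and 30.0 for final logits.
attn_scores = torch.randn(2, 8, 16, 16) * 100
capped = softcap(attn_scores, 50.0)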