Fix Gemma RMSNorm #85

Merged · 2 commits into linkedin:main · Aug 26, 2024
Conversation

@davidgonmar (Contributor) commented Aug 25, 2024

Summary

Fixes #74. Adds support for Gemma's RMSNorm by adding a generic offset parameter to RMSNorm.
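For context, a minimal eager-mode sketch of what the offset parameter generalizes (illustrative only, with made-up names; the actual implementation in this PR is the Triton kernel): offset=0.0 reproduces Llama-style scaling by w, and offset=1.0 reproduces Gemma-style scaling by (1 + w).

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative reference only, not the Triton kernel."""

    def __init__(self, hidden_size, eps=1e-6, offset=0.0):
        super().__init__()
        # Initialized so the effective scale starts at 1 in both cases:
        # ones for offset=0.0 (Llama-style), zeros for offset=1.0 (Gemma-style).
        self.weight = nn.Parameter(
            torch.ones(hidden_size) if offset == 0.0 else torch.zeros(hidden_size)
        )
        self.eps = eps
        self.offset = offset

    def forward(self, x):
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        # Casting details differ slightly between Llama and Gemma; see the
        # review discussion below.
        return (normed * (self.offset + self.weight.float())).to(x.dtype)
```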

Testing Done

Parametrized the RMSNorm tests to cover both Llama's and Gemma's versions (as implemented in Hugging Face's transformers library); a rough sketch of the parametrization idea follows the checklist below.

  • Hardware Type: NVIDIA L4 / L40
  • Ran make test to ensure correctness
  • Ran make checkstyle to ensure code style
  • Ran make test-convergence to ensure convergence
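For illustration only (the real tests live in test/transformers/test_rms_norm.py and compare against the Hugging Face modules; the names, shapes, and reference helper here are hypothetical), the parametrization idea looks roughly like this:

```python
import pytest
import torch


def ref_rms_norm(x, w, eps, offset):
    # Generic reference used only for this sketch: offset=0.0 matches
    # Llama-style scaling (w), offset=1.0 matches Gemma-style scaling (1 + w).
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (offset + w.float())).to(x.dtype)


@pytest.mark.parametrize("offset", [0.0, 1.0])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_rms_norm_offset(offset, dtype):
    torch.manual_seed(0)
    x = torch.randn(2, 8, 16, dtype=dtype)
    w = torch.randn(16, dtype=dtype)
    y = ref_rms_norm(x, w, eps=1e-6, offset=offset)
    # The real tests compare against the HF Llama/Gemma RMSNorm outputs;
    # this sketch only checks shape and dtype are preserved.
    assert y.shape == x.shape and y.dtype == x.dtype
```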

Test outputs

test/transformers/test_rms_norm.py ................................ [100%]
======================== 32 passed in 3.74s ========================

test/convergence/test_mini_models.py ...... [100%]
======================== 6 passed in 32.04s ========================

@davidgonmar reopened this Aug 25, 2024
@davidgonmar marked this pull request as ready for review August 25, 2024 17:34
@ByronHsu (Collaborator) left a comment:

Awesome! Our first kernel-related PR :-) Left a comment.

```python
def _norm(self, x):
    # RMS normalization; the caller upcasts the input to fp32 before this call.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
    output = self._norm(x.float())
```
Collaborator:

@yundai424 I am curious if we need to do the upcasting in our kernel too, or whether we don't because it gets cast to x.dtype anyway. Our old impl doesn't have that but the result is still consistent.

Contributor Author (@davidgonmar):

> @yundai424 I am curious if we need to do the upcasting in our kernel too, or whether we don't because it gets cast to x.dtype anyway. Our old impl doesn't have that but the result is still consistent.

The test results seem to be consistent with the reference too. I think it's a tradeoff between complexity (some models have slight inconsistencies) and exact reproduction. Let me know what you think and I can modify it (or even in another PR, to keep things clean).

Collaborator:

lgtm! Waiting for @yundai424 to do a final check.

Collaborator:

Can we add Gemma to the convergence test as well?

Contributor Author (@davidgonmar):

> Can we add Gemma to the convergence test as well?

Adding it.

Contributor Author (@davidgonmar) commented Aug 25, 2024:

> Can we add Gemma to the convergence test as well?

Done. As mentioned in the Discord channel, I made the tolerance for the bfloat16 convergence test a bit larger, since the casting difference seems to affect it slightly (it's fine for the regular tests, but it appears to add up in the end-to-end training situation). If we end up deciding to match it one-to-one in the future, we can make it stricter. Let me know what you think.
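For illustration, this is the kind of dtype-dependent tolerance being described (the concrete values and helper name here are made up, not the ones used in the PR):

```python
import torch

# bfloat16 has only ~8 bits of mantissa, so small per-step casting differences
# can accumulate over a training run; the end-to-end convergence check therefore
# uses wider atol/rtol than float32. Values below are illustrative only.
DTYPE_TOLERANCES = {
    torch.float32: dict(atol=1e-5, rtol=1e-5),
    torch.bfloat16: dict(atol=1e-2, rtol=1e-2),
}


def assert_losses_close(actual, expected, dtype):
    # Compare training losses with a tolerance chosen per dtype.
    torch.testing.assert_close(actual, expected, **DTYPE_TOLERANCES[dtype])
```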

Collaborator:

My read is yes, we need to cast x to fp32. Regardless of whether it's Llama or Gemma, both do the norm part in fp32; the only difference is that Gemma does the scaling in fp32 too, while Llama does it in mixed precision. But of course we can do it in a separate PR 🤔
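For reference, a rough eager-mode sketch of the two behaviors being described (function names are illustrative; this mirrors the Hugging Face reference modules, not the kernel code):

```python
import torch


def llama_style_rms_norm(x, w, eps=1e-6):
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    # Llama: norm in fp32, then cast back and do the weight scaling in the
    # input (mixed) precision.
    return w * normed.to(x.dtype)


def gemma_style_rms_norm(x, w, eps=1e-6):
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    # Gemma: both the norm and the (1 + w) scaling stay in fp32; cast only
    # at the very end.
    return (normed * (1.0 + w.float())).to(x.dtype)
```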

@ByronHsu (Collaborator) commented Aug 25, 2024:

Could you join our Discord and say hi? The PR looks very solid and we would love more contributions from you: https://discord.gg/CX2YmNmn

@ByronHsu (Collaborator):

Also, can you add your hardware type to the PR description? I just updated the PR template.

@davidgonmar changed the title from "Allow Gemma RMSNorm" to "Fix Gemma RMSNorm" Aug 25, 2024
@yundai424 (Collaborator) left a comment:

LGTM, thanks for the contribution! Let's create a follow-up issue for the dtype thing.

@lancerts merged commit a8e433b into linkedin:main Aug 26, 2024 (1 check passed)
@tyler-romero mentioned this pull request Aug 26, 2024
Development

Successfully merging this pull request may close these issues.

[bug fix] Gemma needs RMSNorm with 1.0 offset
5 participants