Fix Gemma RMSNorm #85
Conversation
Force-pushed from 49c8823 to 52d317a.
Awesome! Our first kernel-related PR :-) Left a comment.
```python
# Hugging Face's Gemma reference implementation (the code under review):
def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
    output = self._norm(x.float())
```
@yundai424 I am curious whether we need to do the upcasting in our kernel too, or whether we can skip it because the output gets cast back to x.dtype anyway. Our old implementation doesn't have it, but the results are still consistent.
> @yundai424 I am curious whether we need to do the upcasting in our kernel too, or whether we can skip it because the output gets cast back to x.dtype anyway. Our old implementation doesn't have it, but the results are still consistent.
The test results seem to be consistent with the reference too. I think it's a tradeoff between complexity (because some models have slight inconsistencies) and exact reproduction. Let me know what you think and I can modify it (or even in another PR, to keep things clean).
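For illustration, here is a minimal sketch (not the actual Triton kernel) of the two casting strategies being discussed; the shapes and eps value are arbitrary:

```python
import torch

def rmsnorm_upcast(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hugging Face reference behavior: compute the norm in fp32, cast back at the end.
    x_fp32 = x.float()
    out = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return out.to(x.dtype)

def rmsnorm_native(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute entirely in the input dtype (what the old implementation effectively did).
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 8, dtype=torch.bfloat16)
# The difference is typically tiny per call, but can accumulate over training.
print((rmsnorm_upcast(x) - rmsnorm_native(x)).abs().max())
```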
LGTM! Waiting for @yundai424 to do a final check.
Can we add Gemma to the convergence test as well?
> Can we add Gemma to the convergence test as well?

Adding it.
> Can we add Gemma to the convergence test as well?
Done. As mentioned in the Discord channel, I made the tolerance for the bfloat16 convergence test a bit larger, since the casting difference seems to affect it slightly (for the regular tests it's fine, but it seems to add up in the end-to-end training situation). If we end up deciding to match it one-to-one in the future, we can make it stricter. Let me know what you think.
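Purely for illustration, a dtype-dependent tolerance in a parametrized convergence test might look like the sketch below; the tolerance values and test name are placeholders, not the ones used in this PR:

```python
import pytest
import torch

@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        # Looser bounds for bf16: the casting difference adds up end to end.
        (torch.bfloat16, 1e-2, 1e-2),
    ],
)
def test_gemma_convergence(dtype, atol, rtol):
    # Training loop elided; the losses from the Liger model and the Hugging Face
    # reference would be compared with torch.allclose(..., atol=atol, rtol=rtol).
    ...
```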
My read is yes, we need to cast x to fp32. Regardless of whether it's Llama or Gemma, both do the norm part in fp32; the only difference is that Gemma does the scaling in fp32 too, while Llama does it in mixed precision. But of course we can do it in a separate PR 🤔
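To make the distinction concrete, a sketch mirroring the Hugging Face modeling code (function names here are illustrative, and this is not the Triton kernel):

```python
import torch

def llama_style(x, weight, eps=1e-6):
    # Norm in fp32, but cast back to x.dtype BEFORE applying the weight
    # (mixed-precision scaling).
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * normed.to(x.dtype)

def gemma_style(x, weight, eps=1e-6):
    # Both the norm and the (1 + weight) scaling happen in fp32,
    # with a single cast back at the very end.
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)
    return (normed * (1.0 + weight.float())).to(x.dtype)
```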
Would you join our Discord and say hi? The PR looks very solid and we would love more contributions from you: https://discord.gg/CX2YmNmn
Also, can you add your hardware type to the PR description? I just updated the PR template.
LGTM, thanks for the contribution! Let's create a follow-up issue for the dtype thing
Summary
Fixes #74. Supports Gemma's RMSNorm by adding a generic offset parameter to RMSNorm.
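As a sketch of the offset idea (an illustrative module, not the Triton kernel): Llama scales the normalized activations by `weight` while Gemma scales by `1 + weight`, so a single `offset` parameter covers both variants:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6, offset: float = 0.0):
        super().__init__()
        # Gemma initializes weight to zeros (effective scale 1 + 0 = 1),
        # Llama to ones; offset = 0.0 gives Llama behavior, 1.0 gives Gemma's.
        init = torch.zeros(hidden_size) if offset else torch.ones(hidden_size)
        self.weight = nn.Parameter(init)
        self.eps = eps
        self.offset = offset

    def forward(self, x):
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * (self.offset + self.weight)
```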
Testing Done
Parametrized the RMSNorm tests to cover both Llama's and Gemma's versions (following Hugging Face's transformers implementations); a sketch of the parametrization follows below.
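A sketch of what that parametrization could look like (test body and names are hypothetical):

```python
import pytest
import torch

@pytest.mark.parametrize("offset", [0.0, 1.0])  # 0.0 ~ Llama, 1.0 ~ Gemma
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_rmsnorm_matches_reference(offset, dtype):
    # Compare the kernel's forward/backward against the matching
    # Hugging Face reference for the given variant; body elided.
    ...
```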
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence

Test outputs