feat: layernorm rms replacement for T5 #107
Conversation
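(Context for the thread below: T5 does not use standard LayerNorm but an RMSNorm variant, scale only, with no mean-centering and no bias, which is what this PR targets for replacement with the fused kernel. A minimal reference sketch of that variant, following the common T5 formulation; the class name and details are illustrative, not kernl's code:)

```python
import torch


class T5RMSNorm(torch.nn.Module):
    """Reference RMSNorm as used by T5: scale only, no mean-centering, no bias.

    Illustrative sketch only, not kernl's implementation.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # RMSNorm scales by the root mean square; classic LayerNorm would also
        # subtract the mean and add a learned bias.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * normed.to(self.weight.dtype)
```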
can you check if there is an error in the reference implementation?
IMO this needs a full test run and a benchmark comparison on a 3090. I will post A10G numbers.
[benchmark screenshot] A10G, without replacement
[benchmark screenshot] A10G, with RMS replacement
[benchmark screenshot] A10G, BERT on the feat/rms-replacement branch (regression check)
[benchmark screenshot] A10G, BERT on the current main branch (regression check)
[screenshot] tests pass
please fix imports
src/kernl/optimizer/layer_norm.py (outdated diff)

@@ -15,6 +15,8 @@
 import torch
+
+from src.kernl.implementations.layer_norm import _layer_norm_fwd_fused_single_pass
remove src.
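(i.e. the import should resolve from the installed package root, presumably:

```python
from kernl.implementations.layer_norm import _layer_norm_fwd_fused_single_pass
```
)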
@@ -15,6 +15,8 @@
 import torch
+
+from src.kernl.optimizer.layer_norm import replace_layer_norm_rms
remove src.
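(likewise for this one, presumably:

```python
from kernl.optimizer.layer_norm import replace_layer_norm_rms
```
)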
lgtm
checked, good speedup end to end
This PR requires a full test run because it modifies the replacement logic.
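(For readers unfamiliar with why the replacement pass is the risky part: kernl swaps modules inside the model graph. A hypothetical sketch of what such a recursive module-swap pass looks like in general, not kernl's actual replace_layer_norm_rms; the attribute names weight and variance_epsilon follow Hugging Face's T5LayerNorm:)

```python
import torch


def swap_rms_norms(module: torch.nn.Module, fused_cls) -> None:
    """Hypothetical sketch of a recursive module-swap pass (illustrative only,
    not kernl's actual code): replace every T5-style RMSNorm submodule with a
    fused implementation that reuses the original weights."""
    for name, child in module.named_children():
        # Match by class name to avoid importing transformers here; HF's
        # T5LayerNorm stores its scale in `weight` and eps in `variance_epsilon`.
        if type(child).__name__ == "T5LayerNorm":
            fused = fused_cls(child.weight.shape[0], eps=child.variance_epsilon)
            fused.weight = child.weight  # share the parameter, no copy
            setattr(module, name, fused)
        else:
            swap_rms_norms(child, fused_cls)
```

A bug here silently changes every normalization layer in the model, which is why a full test run (not just the new unit test) is warranted.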