[T5] lm_head weights initialization: set variance to reciprocal of hidden dim #26441
Conversation
…dden dim. before this PR: lm_head weights were initialized with variance of 1, and it output activations with variance ~= hidden_dim. this is a very high variance for logits, and resulted in initial cross-entropy loss of ~110, which is Very High. after this PR: lm_head weights are initialized with variance of the reciprocal of hidden_dim. this outputs activations with variance ~= 1. this results in initial cross-entropy loss of ~11, which is high, but closer to what we'd expect.
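to make the variance argument concrete, here's a standalone sketch (illustrative only, not part of the PR diff; t5-small-like sizes are assumed):

```python
# standalone sanity check: logit variance under the old vs new lm_head init
import math
import torch

hidden_dim, vocab_size, n_tokens = 512, 32128, 64   # t5-small-like sizes, for illustration
hidden_states = torch.randn(n_tokens, hidden_dim)   # roughly unit-variance activations

old_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
torch.nn.init.normal_(old_head.weight, mean=0.0, std=1.0)               # variance 1
new_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False)
torch.nn.init.normal_(new_head.weight, mean=0.0, std=hidden_dim**-0.5)  # variance 1/hidden_dim

print(old_head(hidden_states).var().item())  # ~ hidden_dim
print(new_head(hidden_states).var().item())  # ~ 1

# for reference: uniform logits give an initial cross-entropy of ln(vocab_size) ~ 10.4,
# which is why ~11 is "high, but closer to what we'd expect"
print(math.log(vocab_size))
```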
Thanks a lot for the deep investigation and making T5 models converge better! Can you run the styling checks?
make fixup && make fix-copies
Then we can merge I think
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. |
|
I would like to argue for setting the initialization of both input and output embedding layers to |
|
To provide more context: neither the default initialization ( |
|
I wasn't sure what to make of Patrick's I surmise that the reasoning was:
That said, when I tried initializing both the embedding and the lm_head with std=0.05, it performed more or less the same (well, fractionally worse).
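for reproducibility, roughly what that experiment looked like (a sketch; module names follow HF's T5ForConditionalGeneration, and the v1.1-small config is just an example):

```python
# sketch of the std=0.05 experiment described above: re-initialize both the shared
# embedding and the (untied) lm_head of a fresh model before training
import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config.from_pretrained("google/t5-v1_1-small")  # v1.1 models have untied lm_head
model = T5ForConditionalGeneration(config)                 # fresh, randomly initialized

with torch.no_grad():
    model.shared.weight.normal_(mean=0.0, std=0.05)
    model.lm_head.weight.normal_(mean=0.0, std=0.05)
```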
|
|
Thanks for the quick turnaround. In my experiment:
|
|
how big is your hidden_dim? for t5-small, |
|
I am also running my quick tests on |
|
okay yeah certainly I agree that we should modify both the embedding and the lm_head weights. I think we don't actually have any info from the mesh tensorflow repository on how to initialize an untied |
|
oh, there is a reference to tied vs untied lm_head: I think they're saying "if the embedding is tied, we anticipate that the hidden states will have too much variance, and so we correct it". so whatever std they use to initialize the embedding is higher than the std an lm_head would want? this feels a bit weird. it kinda sounds like the embedding was initialized with something much larger than 0.05, and they're trying to compensate. is this a clue as to why HF initializes embedding to std=1? whereas if it's untied, they treat it as a |
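to illustrate the compensation they describe (toy numbers and names of my own; the actual modeling_t5.py logic gates this on tie_word_embeddings):

```python
# when the lm_head shares weights with an embedding whose std is ~1, the decoder output is
# rescaled by d_model**-0.5 before the vocab projection so the logits stay ~unit variance
import torch

d_model, vocab_size, n_tokens = 512, 32128, 16
sequence_output = torch.randn(n_tokens, d_model)   # roughly unit-variance decoder output
tied_embedding = torch.randn(vocab_size, d_model)  # std ~ 1, shared with the lm_head

logits_untempered = sequence_output @ tied_embedding.T                  # variance ~ d_model
logits_rescaled = (sequence_output * d_model**-0.5) @ tied_embedding.T  # variance ~ 1

print(logits_untempered.var().item(), logits_rescaled.var().item())
```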
|
maybe we can learn something about how to initialize the weights, based on how they ended up distributed after Google finished training the model? I took an off-the-shelf model:

```python
from transformers import T5ForConditionalGeneration

model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-small')
```

and I analysed how all of its layers were distributed. [histograms of the trained weight distributions appeared here: embedding std; encoder self-attn, decoder self-attn, and decoder cross-attn; encoder FFN and decoder FFN; encoder per-block layernorms; decoder per-block layernorms]
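the gist of the analysis, as a rough sketch (the bucketing into the groups above was done by hand, by parameter name):

```python
# print the empirical std of every weight tensor in the pretrained checkpoint
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")

with torch.no_grad():
    for name, param in model.named_parameters():
        print(f"{name:90s} std={param.float().std().item():.4f}")
```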
|
t5-small pretraining |
|
The T5X repository has become somewhat of a reference implementation to pretrain T5. Here they use stddev=1.0 as an initialization: https://github.com/google-research/t5x/blob/b051e46075fdcb02fcdd4dc648dd9560243bfdb2/t5x/examples/t5/network.py#L289 That aligns with what we currently have in Transformers. Should we maybe open an issue in T5X asking about the differences? Because it indeed seems like in the original implementation they used 0.05 as init. |
do they though? it looks to me like they use
|
what did you think of my investigation that suggested that MTF initialized untied lm_head using |
Agree that there is a difference! Let's maybe open an issue in T5X and link it here? Think T5X is still pretty active |
|
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored. |
|
Thanks a lot for posting on T5X @Birch-san - 🤞 someone from the T5 team looks into it |
|
in the meanwhile, maybe worth un-staling/re-opening this PR? |
|
Haha it appears the stale bot is over-aggressive |
|
@dhruvbird Thanks for sharing your results! It appears this wasn't merged in as it was pending confirmation from the T5 team. Although this would technically be a breaking change, it does appear to be useful for training the model. WDYT about adding it, @ArthurZucker? |
|
Yeah, sounds good in general; the issue is that it could affect the entire field of diffusion that relies on T5. cc @sayakpaul |
|
In the diffusion community, T5 is not generally trained from scratch, though. Just a heads up. |
|
I also note that attention is not being scaled by the usual 1/sqrt(d_head) factor.
instead, this scale factor is fused into the weight init of the Q projection:
it's likely Google did this for performance reasons (I expect only marginal, and certainly outmoded by flash attention). under AdamW, this optimization is harmful. you can only do this trick if your optimizer scales lr by the RMS of the params (e.g. Adafactor, LARS+AdamW, or AdamWScaled). so I also recommend changing the Q weight init to be the same as the K and V weight init ( |
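to spell out the fusion (a rough sketch, not the T5 source): both formulations give comparably-scaled attention scores at init, but they stop being equivalent once AdamW starts updating Q, since AdamW's step size doesn't shrink with the parameter scale:

```python
# standard attention scales scores by d_head**-0.5; the T5-style trick folds that factor
# into the Q projection's init std and leaves the scores unscaled
import torch

d_model, d_head, n_tokens = 512, 64, 16
x = torch.randn(n_tokens, d_model)  # roughly unit-variance hidden states

w_k = torch.randn(d_head, d_model) * d_model**-0.5                      # K init std = d_model**-0.5
w_q_standard = torch.randn(d_head, d_model) * d_model**-0.5             # Q init matches K
w_q_fused = torch.randn(d_head, d_model) * (d_model * d_head) ** -0.5   # extra d_head**-0.5 folded in

scores_standard = (x @ w_q_standard.T) @ (x @ w_k.T).T * d_head**-0.5   # explicit scaling
scores_fused = (x @ w_q_fused.T) @ (x @ w_k.T).T                        # no scaling needed

print(scores_standard.var().item(), scores_fused.var().item())          # both ~ 1 at init
```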
|
#31167 is coming, I gotta review it, scaling would be breaking as it is not done currently |
What does this PR do?
before this PR: lm_head weights were initialized with variance of 1, and it output activations with variance ~= hidden_dim. this is a very high variance for logits, and resulted in initial cross-entropy loss of ~110, which is Very High.
after this PR: lm_head weights are initialized with variance of the reciprocal of hidden_dim. this outputs activations with variance ~= 1. this results in initial cross-entropy loss of ~11, which is high, but in line with what we'd expect.
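until this lands, a hedged workaround sketch that applies the proposed variance by hand to a freshly constructed (untied) model; the checkpoint choice is just an example:

```python
# apply the proposed init manually: draw lm_head weights with variance 1/d_model,
# i.e. std = d_model**-0.5
import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config.from_pretrained("google/t5-v1_1-small")  # untied: tie_word_embeddings=False
model = T5ForConditionalGeneration(config)                 # fresh, randomly initialized

with torch.no_grad():
    model.lm_head.weight.normal_(mean=0.0, std=config.d_model ** -0.5)
```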
Fixes #16749 (again)
Who can review?
@ArthurZucker @younesbelkada