
Conversation

@Birch-san

What does this PR do?


before this PR: lm_head weights were initialized with variance of 1, and it output activations with variance ~= hidden_dim. this is a very high variance for logits, and resulted in initial cross-entropy loss of ~110, which is Very High.

after this PR: lm_head weights are initialized with variance equal to the reciprocal of hidden_dim. this outputs activations with variance ~= 1, and results in an initial cross-entropy loss of ~11, which is high, but in line with what we'd expect.
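for intuition, here's a rough back-of-the-envelope check of both claims (the hidden_dim/vocab numbers below are just illustrative, t5-small-ish; this is not part of the change itself):

import torch

hidden_dim, vocab = 512, 32128  # illustrative, roughly t5-small sized
x = torch.randn(256, hidden_dim)  # hidden states with roughly unit variance

w_before = torch.randn(vocab, hidden_dim)                      # std=1 (old init)
w_after = torch.randn(vocab, hidden_dim) * hidden_dim ** -0.5  # std=hidden_dim**-.5 (this PR)

print((x @ w_before.T).var().item())  # ~= hidden_dim (~512): badly over-scaled logits
print((x @ w_after.T).var().item())   # ~= 1: well-scaled logits
# with well-scaled logits the softmax is roughly uniform, so initial CE loss ~= ln(vocab) ~= 10.4
print(torch.tensor(float(vocab)).log().item())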

Fixes #16749 (again)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @younesbelkada

Contributor

@younesbelkada left a comment


Thanks a lot for the deep investigation and making T5 models converge better! Can you run the styling checks?

make fixup && make fix-copies

Then we can merge I think

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@abdulfatir

I would like to argue for setting the initialization of both input and output embedding layers to normal(0, 0.05) as discussed in this comment: #16749 (comment). I am not sure why this was actually not done (cc @patrickvonplaten).

@abdulfatir

abdulfatir commented Sep 30, 2023

To provide more context: neither the default initialization (normal(0, 1)) nor what's proposed in this PR works in my experiment when I train from random weights. normal(0, 0.05) works and should probably be the default to align with the original implementation.

@Birch-san
Author

Birch-san commented Sep 30, 2023

I wasn't sure what to make of Patrick's stddev=0.05, as no explanation was given.

I surmise that the reasoning was:

  • VocabEmbedding#__init__ initializes embedding via the mtf.layers.embedding_weights helper.
    It might default to _scale_variable_like_classifier_weights=False (I didn't see any .gin file overriding this, but then again I never found the .gin file used to train T5), which would make initializer=None.
    • side-note: I didn't see any capability to support an untied lm_head. maybe that means we should default to using a tied lm_head? though for vocab as small as 32k, it's probably better to leave it untied.
  • when initializer is None, mtf.layers.embedding_weights initializes the embedding via tf.random_normal_initializer().
  • tf.random_normal_initializer() defaults to std=0.05.
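(for reference, that 0.05 default is easy to confirm — an illustrative TF2 check, not taken from the T5 codebase:)

import tensorflow as tf

init = tf.random_normal_initializer()  # TF2 defaults: mean=0.0, stddev=0.05
samples = init(shape=(512, 32128))
print(float(tf.math.reduce_std(samples)))  # ~= 0.05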

That said, when hidden_dim==512 (as in t5-small): std=0.05 is awfully close to std=hidden_dim**-.5=0.044.

I tried initializing both the embedding and the lm_head with std=0.05. It performed more or less the same (well, fractionally worse).


@abdulfatir

Thanks for the quick turnaround. In my experiment:

  • Current default std=1: Starts at a high loss and converges to a high value.
  • lm_head with std=hidden_dim**-.5: Starts at a lower loss than before but diverges after a few hundred steps.
  • initializer_factor=0.05: Starts at an even lower loss and converges nicely.
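(For reference, a sketch of one way to reproduce the third setting in transformers; flan-t5-small is assumed only because that's the model under test, and this isn't necessarily the exact setup used:)

from transformers import T5Config, T5ForConditionalGeneration

# train-from-scratch setup: initializer_factor scales the per-module init stds in _init_weights,
# so with the current (pre-PR) init code, 0.05 gives normal(0, 0.05) for the embedding and lm_head
config = T5Config.from_pretrained('google/flan-t5-small', initializer_factor=0.05)
model = T5ForConditionalGeneration(config)  # randomly initialized, no pretrained weights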

@Birch-san
Author

Birch-san commented Sep 30, 2023

how big is your hidden_dim? for t5-small, 512**-.5=0.044 is a very similar number to 0.05.

@abdulfatir

I am also running my quick tests on flan-t5-small.

@Birch-san
Author

okay yeah certainly I agree that we should modify both the embedding and the lm_head weights.

I think we don't actually have any info from the mesh tensorflow repository on how to initialize an untied lm_head. I don't think they support that.

@Birch-san
Author

I tried a model with a bigger hidden dim (t5-base, 768), so that there would be a bigger difference between 0.05 and hidden_dim**-.5.


no discernible difference by 600 steps.

so I don't know which is more correct out of dim**-.5 or 0.05, but I think empirically there's not a big difference.

certainly I think initializing the embedding to 0.05 matches the MTF repository.

but I don't think the MTF repository supports an untied head at all, so I think that's up to us, unless there's some precedent for what to do there. there are two schools of thought:

  • initialize the lm_head as though it were tied to the embedding (0.05)
  • initialize the lm_head like you'd initialize most Linear layers (dim**-.5)

I kinda think treating the lm_head as untied (dim**-.5) makes more sense.

numerically and empirically I think there's not much difference. I think the more important thing is don't initialize either the embedding or the lm_head with std=1.

@Birch-san
Author

Birch-san commented Sep 30, 2023

oh, there is a reference to tied vs untied lm_head:
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586-L594

I think they're saying "if the embedding is tied, we anticipate that the hidden states will have too much variance, and so we correct it". so whatever std they use to initialize the embedding is higher than the std an lm_head would want? this feels a bit weird. it kinda sounds like the embedding was initialized with something much larger than 0.05, and they're trying to compensate. is this a clue as to why HF initializes the embedding to std=1?

whereas if it's untied, they treat it as a dense layer with kernel_initializer=None:
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L88-L90
so I think they initialize untied lm_head as std=hidden_dim**-.5?

@Birch-san
Author

Birch-san commented Oct 1, 2023

maybe we can learn something about how to initialize the weights, based on how they ended up distributed after Google finished training the model?

I took an off-the-shelf model:

from transformers import T5ForConditionalGeneration
model: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-small')

and I analysed how all of its layers were distributed.
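roughly along these lines (a sketch iterating over the model loaded above, not the exact script I used):

for name, param in model.named_parameters():
    # T5 has no biases, so every parameter here is a weight matrix, embedding, or norm scale
    print(f'{name}: std={param.data.float().std().item():.4f}')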

embedding std is 11.6375
lm_head std is 1.1614

considering encoder self-attn, decoder self-attn, decoder cross-attn…
attn q is ~0.05 or ~hidden_dim**-.5 (hard to tell the difference)
attn k is 0.3~0.4
attn v is 0.6~0.8
attn o is 0.6~1.0

considering encoder FFN, decoder FFN…
FFN wi_0 is 0.32~0.36
FFN wi_1 is 0.85~1.17
FFN wo is 0.53~0.73

encoder per-block layernorms ~0.02
encoder final layernorm ~0.05 or ~hidden_dim**-.5 (hard to tell the difference)

decoder per-block layernorms 0.07~0.10
decoder final layernorm ~0.7

@Birch-san
Author

Birch-san commented Oct 1, 2023

t5-small pretraining .gin config apparently lives here. I hadn't been able to find it in the MTF repository.
from allenai/unifiedqa#40 (comment).

@patrickvonplaten
Contributor

The T5X repository has become somewhat of a reference implementation for pretraining T5. Here they use stddev=1.0 as an initialization: https://github.com/google-research/t5x/blob/b051e46075fdcb02fcdd4dc648dd9560243bfdb2/t5x/examples/t5/network.py#L289

That aligns with what we currently have in Transformers. Should we maybe open an issue in T5X asking about the differences because it indeed seems like in the original implementation they used 0.05 as init.

@Birch-san
Author

@patrickvonplaten

Here they use stddev=1.0 as an initialization

do they though? it looks to me like they use scale=1.0, not std=1.0.
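assuming that line is a flax/jax variance_scaling initializer (that's my reading of the "scale" there, not confirmed), scale=1.0 with fan_in works out to std ~= hidden_dim**-.5, not 1.0. illustrative check with made-up t5-small-ish shapes:

import jax
from jax.nn.initializers import variance_scaling

hidden_dim, vocab = 512, 32128
init = variance_scaling(1.0, 'fan_in', 'truncated_normal')
kernel = init(jax.random.PRNGKey(0), (hidden_dim, vocab))
print(kernel.std())  # ~= hidden_dim**-0.5 ~= 0.044, nowhere near 1.0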

@Birch-san
Author

Birch-san commented Oct 2, 2023

@patrickvonplaten

Should we maybe open an issue in T5X asking about the differences because it indeed seems like in the original implementation they used 0.05 as init.

what did you think of my investigation that suggested that MTF initialized untied lm_head using std=hidden_dim**-.5?

@patrickvonplaten
Contributor


Agree that there is a difference! Let's maybe open an issue in T5X and link it here? Think T5X is still pretty active

@github-actions
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Birch-san
Author

google-research/t5x#1477

@patrickvonplaten
Contributor

Thanks a lot for posting on T5X @Birch-san - 🤞 someone from the T5 team looks into it

@vadimkantorov

vadimkantorov commented Dec 26, 2023

in the meanwhile, maybe worth un-staling/re-opening this PR?

@vadimkantorov

Haha it appears the stale bot is over-aggressive

@huggingface huggingface deleted a comment from github-actions bot Jan 22, 2024
@patrickvonplaten patrickvonplaten added the WIP label Jan 23, 2024
@dhruvbird

I ran into exactly this issue today, and the proposed fix made the grad_norm during training misbehave less (at least so far). Here are the charts for reference.


Purple is on master, and blue is with the fix in this MR.

It would be super if we can merge this PR into master so that anyone trying to train a T5 model from scratch (i.e. no fine-tuning using pre-trained weights) can benefit from this fix.

@amyeroberts
Contributor

@dhruvbird Thanks for sharing your results! It appears this wasn't merged in as it was pending confirmation from the T5 team. Although this would technically be a breaking change, it does appear to be useful for training the model. WDYT about adding it, @ArthurZucker?

@ArthurZucker
Collaborator

Yeah, sounds good in general; the issue is that it could affect the entire field of diffusion that relies on T5. cc @sayakpaul

@sayakpaul
Member

In the diffusion community, T5 is not generally trained from scratch, though. Just a heads up.

@Birch-san
Author

Birch-san commented Sep 6, 2024

I also note that attention is not being scaled by the usual head_dim**-.5:

) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

instead, this scale factor is fused into the weight init of the Q projection:

module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))

it's likely Google did this for performance reasons (I expect the gain was only marginal, and it's certainly outmoded by flash attention).
they made this decision knowing they were using the Adafactor optimizer, which exhibits equivalent training dynamics either way.

under AdamW, this optimization is harmful. you can only do this trick if your optimizer scales lr by the RMS of the params (e.g. Adafactor, LARS+AdamW, or AdamWScaled).

so I also recommend changing the Q weight init to be the same as the K and V weight init (d_model**-.5), and scaling Q•K.T by head_dim**-.5 (i.e. the typical attention formulation). upgrading to torch SDPA would of course be nice too.
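for concreteness, a sketch of what I mean (toy shapes; this is not the actual transformers code):

import torch

d_model, n_heads, head_dim = 512, 8, 64  # toy t5-small-ish shapes

# init q the same way as k/v, rather than folding head_dim**-.5 into its std
q_proj = torch.nn.Linear(d_model, n_heads * head_dim, bias=False)
torch.nn.init.normal_(q_proj.weight, mean=0.0, std=d_model ** -0.5)

# and apply the usual scaling at attention time instead
def attn_scores(query_states, key_states):
    # query_states, key_states: (batch, heads, seq, head_dim)
    return (query_states @ key_states.transpose(-2, -1)) * head_dim ** -0.5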

@ArthurZucker
Collaborator

#31167 is coming, I gotta review it, scaling would be breaking as it is not done currently



Successfully merging this pull request may close these issues.

Large differences between T5 weight initialization in TF and torch
