-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transformer building blocks tutorial #3075
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3075
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New FailuresAs of commit c28f568 with merge base 4ed884d (): NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(discussed offline) more thorough review coming when the rendered docs are available. Sadly, we require nightly for the NJT stuff to work.
If you are looking for an out-of-the-box implementation of a popular transformer | ||
architecture, note that there are many open-source libraries that provide them, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: mention / link xFormers here as well :)
# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``. | ||
|
||
# We demonstrate examples of implementing the rest of the nn layers | ||
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume the name / hosting for this repo will change before publishing the tutorial?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking really good! Love how it reads, and I think you did a good job introducing the primitives and covering the background.
I left a bunch of silly editorial nits but nothing too major.
I guess we still need some flex + NJT demonstration once #136792 lands?
|
||
.. customcarditem:: | ||
:header: [Title TBD] Unbundling nn.Transformer modules for gains and profits | ||
:card_description: This tutorial goes over recommended best practices for implementing Transformers. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:card_description: This tutorial goes over recommended best practices for implementing Transformers. | |
:card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch. |
# One example of this is in ``nn.TransformerDecoderLayer`` where the query comes | ||
# from the decoder and the key/value come from the encoder. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this actually work with NJTs? lol I'm surprised you're not running into problems with having different nested ints for query / key+value
# There has been a long standing issue with ``nn.MultiheadAttention`` and | ||
# ``scaled_dot_product_attention`` where if a row was fully masked, the output | ||
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_. | ||
# This is because the softmax operation would divide by zero. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I'm being super pedantic here, but I think the NaNs actually occurred before division due to the special trick softmax uses for numerical stability (subtracting the max).
maybe it's better stay at the conceptual level and mention softmax over an empty set being undefined?
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_. | ||
# This is because the softmax operation would divide by zero. | ||
# | ||
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I'm not sure it's within scope of a tutorial to mention specific PRs, but I think it is valuable to say that rolling a custom MHA doesn't run into the same NaN issues as the old nn.MHA because we're not employing a fused kernel with this problem (i.e. the fastpath case, which I think still exhibits the NaN behavior even after @drisspg's fix)
If you wanted, you could mention that NJT's ability to model raggedness appropriately makes it possible to distinguish when there is an empty sequence
Description
This adds the tutorial for transformer building blocks following the outline discussed in nn/optim triage on Friday (9/27/24) here https://docs.google.com/document/d/1TMrd0bDiM9-lcFHi079edkMRP1Ux5MTxt4lI1diiAKI/edit
This tutorial also links to a repo https://github.com/mikaylagawarecki/temp which
nn.Transformer
-related layers in pytorch in a NJT friendly manner (basically no more*_padding_mask
)To run this tutorial with correctness, we likely need torch 2.6
There are a few pending sections in this tutorial that hope to demonstrate more cool examples of composing feature with NJT that are pending some PRs. Not sure whether we should consider this a v0 and add those as v1?
index_put_
(KV caching section) Add support for index_put_ in NT pytorch#135722FlexAttention
+ NJT FlexAttention support for NJT pytorch#136792Checklist