Transformer building blocks tutorial #3075

Draft · wants to merge 6 commits into main
Conversation

@mikaylagawarecki (Contributor) commented Oct 4, 2024

Description

This adds a tutorial on transformer building blocks, following the outline discussed in nn/optim triage on Friday (9/27/24): https://docs.google.com/document/d/1TMrd0bDiM9-lcFHi079edkMRP1Ux5MTxt4lI1diiAKI/edit

This tutorial also links to a repo, https://github.com/mikaylagawarecki/temp, which:

  • has examples of implementing the rest of the nn.Transformer-related layers in PyTorch in an NJT-friendly manner (basically, no more *_padding_mask arguments; see the sketch after this list)
  • notes some cases that we don't intend to demonstrate (e.g. see here)
  • removes the fast-path logic from MHA/TEL/TE
  • sanity-checks that for MHA/TEL/TDL, new_layer + NJT + compile gives both correctness and perf gains over nn.layer + dense + mask + compile (as we expect :)). (TE, TD, and T are just higher-level wrappers, so we didn't test those.)

To run this tutorial correctly, we likely need torch 2.6.

There are a few sections of this tutorial that aim to demonstrate more examples of composing features with NJT but are blocked on some pending PRs. Not sure whether we should consider this a v0 and add those as a v1?

Checklist

  • The issue that is being fixed is referred to in the description (see above, "Fixes #ISSUE_NUMBER")
  • Only one issue is addressed in this pull request
  • Labels from the issue that this PR is fixing are added to this pull request
  • No unnecessary issues are included in this pull request


pytorch-bot bot commented Oct 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3075

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit c28f568 with merge base 4ed884d:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jbschlosser (Contributor) left a comment


(discussed offline) more thorough review coming when the rendered docs are available. Sadly, we require nightly for the NJT stuff to work.

Comment on lines +32 to +33
If you are looking for an out-of-the-box implementation of a popular transformer
architecture, note that there are many open-source libraries that provide them,

nit: mention / link xFormers here as well :)

# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.

# We demonstrate examples of implementing the rest of the nn layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this

I assume the name / hosting for this repo will change before publishing the tutorial?

@jbschlosser (Contributor) left a comment


Looking really good! Love how it reads, and I think you did a good job introducing the primitives and covering the background.

I left a bunch of silly editorial nits but nothing too major.

I guess we still need some flex + NJT demonstration once #136792 lands?


.. customcarditem::
:header: [Title TBD] Unbundling nn.Transformer modules for gains and profits
:card_description: This tutorial goes over recommended best practices for implementing Transformers.

Suggested change
:card_description: This tutorial goes over recommended best practices for implementing Transformers.
:card_description: This tutorial goes over recommended best practices for implementing Transformers with native PyTorch.

Comment on lines +560 to +561
# One example of this is in ``nn.TransformerDecoderLayer`` where the query comes
# from the decoder and the key/value come from the encoder.
Does this actually work with NJTs? lol, I'm surprised you're not running into problems with having different nested ints for the query vs. the key/value.

# There has been a long standing issue with ``nn.MultiheadAttention`` and
# ``scaled_dot_product_attention`` where if a row was fully masked, the output
# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
# This is because the softmax operation would divide by zero.
nit: I'm being super pedantic here, but I think the NaNs actually occurred before the division, due to the special trick softmax uses for numerical stability (subtracting the max).

Maybe it's better to stay at the conceptual level and mention that softmax over an empty set is undefined?

# of the attention layer would be NaN. See `issue <https://github.com/pytorch/pytorch/issues/41508>`_.
# This is because the softmax operation would divide by zero.
#
# Thanks to `this PR <https://github.com/pytorch/pytorch/pull/133882>`_
nit: I'm not sure it's within the scope of a tutorial to mention specific PRs, but I think it is valuable to say that rolling a custom MHA doesn't run into the same NaN issues as the old nn.MHA, because we're not employing a fused kernel with this problem (i.e. the fastpath case, which I think still exhibits the NaN behavior even after @drisspg's fix).

If you wanted, you could mention that NJT's ability to model raggedness appropriately makes it possible to distinguish when there is an empty sequence.
