
Add custom decompositions for cross entropy loss for the nvfuser executor #2043

Merged: 33 commits merged from pbasu_loss_fwd_bwd into main on May 27, 2025

Conversation

@protonu (Collaborator) commented May 6, 2025

This PR adds custom decompositions for cross-entropy loss for the nvFuser executor.
These decompositions improve performance and enable further optimization inside nvFuser.

For cross-entropy loss forward:

  1. We move the take_along_axis computation before the log softmax is computed, which reduces memory traffic for the inputs (a sketch follows this item).
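
As a rough illustration, here is a minimal PyTorch sketch of the algebra behind that reordering. The function name and the mean reduction are assumptions for this example, and plain torch ops stand in for whatever the actual decomposition uses inside the executor. Gathering the target logit first means only a per-row logsumexp of the inputs is needed, rather than the full log-softmax output:

```python
import torch

def cross_entropy_fwd_sketch(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # logits: (N, C) float tensor, target: (N,) int64 class indices.
    # Gather the logit of the correct class for each row *before* any softmax;
    # this is the take_along_axis that gets hoisted above the log softmax.
    picked = torch.take_along_dim(logits, target.unsqueeze(-1), dim=-1).squeeze(-1)  # (N,)

    # log_softmax(x)[i, t] = x[i, t] - logsumexp(x[i, :]), so the per-sample loss
    # only needs a per-row logsumexp of the inputs, not the full (N, C) output.
    lse = torch.logsumexp(logits, dim=-1)  # (N,)
    per_sample_loss = lse - picked

    return per_sample_loss.mean()  # 'mean' reduction, as in F.cross_entropy's default
```

The result can be sanity-checked against torch.nn.functional.cross_entropy(logits, target).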

For cross-entropy loss backward:

  1. We replace a scatter op with an iota and a where op, since scatter is not exposed in nvFuser.
  2. We eliminate a reduction that shows up when the backward is computed as nll_loss backward followed by log_softmax backward (a sketch covering both points follows this list).
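
A minimal PyTorch sketch covering both backward points, under the same caveats (illustrative function name, mean reduction assumed, plain torch ops standing in for the executor's ops, no ignore_index or label-smoothing handling):

```python
import torch

def cross_entropy_bwd_sketch(grad_out: torch.Tensor,
                             logits: torch.Tensor,
                             target: torch.Tensor) -> torch.Tensor:
    # grad_out: scalar grad of the mean-reduced loss; logits: (N, C); target: (N,).
    n, c = logits.shape

    # (1) iota + where instead of scatter: compare a 0..C-1 range against each
    # row's target index to build the one-hot mask.
    iota = torch.arange(c, device=logits.device)                      # (C,)
    one_hot = torch.where(iota == target.unsqueeze(-1),
                          logits.new_ones(()), logits.new_zeros(()))  # (N, C)

    # (2) Fused nll_loss backward + log_softmax backward:
    # d(loss_i)/d(logits_i) = softmax(logits_i) - one_hot_i.
    # The unfused log_softmax backward would re-reduce its incoming gradient over
    # the class dim, but each one-hot row sums to 1, so that reduction drops out
    # once the two backwards are combined.
    probs = torch.softmax(logits, dim=-1)
    return grad_out * (probs - one_hot) / n  # 1/N from the mean reduction
```

Gradients from this sketch can be checked against torch.autograd applied to the forward sketch above.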

cc @tfogal

@jjsjann123 (Collaborator) left a comment

Want to double check the forward implementation. Kinda looks a bit strange to me.

@protonu protonu force-pushed the pbasu_loss_fwd_bwd branch from a60395a to b045ccb Compare May 15, 2025 22:19
@protonu protonu requested review from jjsjann123 and crcrpar May 15, 2025 22:20
protonu and others added 2 commits May 19, 2025 19:46
@protonu protonu requested a review from crcrpar May 19, 2025 23:47
@crcrpar (Collaborator) left a comment

Just wondering if you have some cross-entropy loss numbers to compare this with the existing implementations, e.g. https://github.com/Lightning-AI/lightning-thunder/blob/c6928015914fdbdd708fd8e87fbd9d9c1b4a40ef/thunder/executors/triton_crossentropy.py?

@protonu (Collaborator, Author) commented May 21, 2025

Just wondering if you have some cross-entropy loss numbers to compare this with the existing implementations, e.g. https://github.com/Lightning-AI/lightning-thunder/blob/c6928015914fdbdd708fd8e87fbd9d9c1b4a40ef/thunder/executors/triton_crossentropy.py?

@crcrpar I did compare performance against torch.compile (which uses Triton, but is that the same as the link you sent?).
For the performance measurements I used the changes in this PR plus a pre-segmentation pass in nvFuser (NVIDIA/Fuser#4399).

This is a benchmark I was going to add:
NVIDIA/Fuser#4472

@mruberry (Collaborator) commented:

@IvanYashchuk, @jjsjann123, would you like to review again?

@mruberry mruberry requested a review from beverlylytle May 22, 2025 19:12
@mruberry (Collaborator) commented:

@beverlylytle Would you like to take a look, too?

@beverlylytle (Collaborator) left a comment

👍

protonu and others added 3 commits May 23, 2025 11:32
Co-authored-by: beverlylytle <57254617+beverlylytle@users.noreply.github.com>
@IvanYashchuk IvanYashchuk enabled auto-merge (squash) May 27, 2025 15:58
@mruberry (Collaborator) commented:

@IvanYashchuk did you want to take a last look?

@mruberry (Collaborator) left a comment

Cool!

@IvanYashchuk IvanYashchuk merged commit 9a7355d into main May 27, 2025
49 checks passed
@IvanYashchuk IvanYashchuk deleted the pbasu_loss_fwd_bwd branch May 27, 2025 20:13
KaelanDt pushed a commit that referenced this pull request May 29, 2025
…utor (#2043)