
Add feature ligerceloss #2741


Open · wants to merge 9 commits into main

Conversation

mananchawla2005

PR: Add LigerFusedCrossEntropyLoss

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other

Closes #2692

Changelog

  • Added a new LigerFusedCrossEntropyLoss class that provides memory-efficient cross-entropy loss using the fused kernels from the liger-kernel library
  • Implemented proper handling of distributed tensors (DTensor) with gradient hooks
  • Added comprehensive unit tests comparing against standard PyTorch cross entropy
  • Added docstrings with usage examples

Test plan

  • Added unit tests in test_liger_ce_loss.py comparing against F.cross_entropy
  • Tests verify correctness with:
    • Regular tensor inputs
    • Batched inputs requiring reshaping
    • Ignored indices
  • Added proper docstrings with usage examples
  • Tests pass locally on CUDA device
  • Pre-commit hooks and linters pass

UX

Example usage in docstring:

# Initialize model and loss
model = Transformer(...)  # model with skip_output_layer=True
loss = LigerFusedCrossEntropyLoss()
loss.set_model_output(model)  # This captures model's output layer

# Forward pass
hidden_states = model(inputs)  # [batch_size, seq_len, hidden_dim]
targets = labels  # [batch_size, seq_len]

# If needed, reshape to [batch_size*seq_len, hidden_dim]
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
targets = targets.view(-1)

loss_value = loss(hidden_states, targets)

The implementation provides better performance and memory efficiency compared to the chunked LinearCrossEntropyLoss by:

  1. Using fused CUDA kernels
  2. Combining linear projection and cross entropy in a single operation
  3. Properly handling distributed tensors

All tests verify numerical correctness against PyTorch's native cross entropy within expected tolerances.


Hi PyTorch Team,

This is my first PR to a machine learning project, and I’ve tried to ensure the code is correct and well-structured. I’ve implemented the functionality for LigerCeLoss and included a test case that verifies its behavior with both masked and reshaped inputs.

Due to hardware limitations, I wasn't able to fully validate the distributed functionality or run the entire test suite across multiple GPUs. However, I’ve implemented support for DTensor-based weights, hidden states, and targets, and included the logic for gather → fuse → scatter to handle gradient computation for sharded weights in distributed settings.

Looking forward to your feedback! @pbontrager @joecummings


pytorch-bot bot commented May 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2741

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @mananchawla2005!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@mananchawla2005
Author

CLA filled!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the "CLA Signed" label on May 16, 2025
@pbontrager (Contributor) left a comment

Thank you for this contribution! Really clean first pass at the problem. I left some comments around testing and questions about how you're handling DTensors. For linting, you can check out our contributing guide on how to set up pre-commit hooks.

# Validate the results are close enough
assert_expected(fused_loss, standard_loss, rtol=1e-2, atol=1e-2)

def test_liger_fused_cross_entropy_loss_with_reshape(self):
Contributor

For our SFTLoss type we can assume the input is "[bsz, seq_len, emb_dim]", so I don't think we need this second test.

Contributor

I think we should have a second test, but it should be a distributed test: same as the first test, but requiring 4 GPUs with FSDP size 2 and TP size 2. If you need help on how to initialize the model that way I can give you the code.
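For reference, here is a hedged sketch of how a 4-GPU FSDP(2) x TP(2) setup could be initialized for such a test; the mesh names, the parallelize plan, and the "output" layer name are assumptions for illustration, not torchtune's actual test code:

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# model: the torchtune Transformer under test (assumed to exist already)
# 2D mesh over 4 GPUs: outer dim for FSDP, inner dim for tensor parallelism
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

# shard the final projection column-wise over the TP dim ("output" is a hypothetical module name)
parallelize_module(model, mesh["tp"], {"output": ColwiseParallel()})

# shard parameters over the data-parallel dim with FSDP2
fully_shard(model, mesh=mesh["dp"])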



class TestLigerFusedCrossEntropyLoss:
def test_liger_fused_cross_entropy_loss(self):
Contributor

Since this test requires CUDA, you should add the "@gpu_test(gpu_count=1)" decorator from "from tests.test_utils import gpu_test". Along with testing the loss value, I think it would be good to test a single forward and backward pass with an optimizer step to ensure all the gradients are propagating back correctly too. You can use "fixed_init_model" (also from test_utils) as well to make it easier to initialize the model the same way each time.

Contributor

I think it would be good to add @pytest.mark.parametrize("compile", [False, True]) to the test and pass in compile as an argument on whether to call apply_compile_strategy on the loss
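Putting those suggestions together, a hedged sketch of what such a test could look like; the import paths, fixed_init_model's exact signature, and the behaviour of set_model_output are assumptions based on the snippets in this PR:

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from tests.test_utils import assert_expected, fixed_init_model, gpu_test
from torchtune.modules.loss import LigerFusedCrossEntropyLoss  # import path assumed


class TinyModel(nn.Module):
    # stand-in for a Transformer with skip_output_layer=True: forward returns hidden
    # states, and `output` is the projection captured by set_model_output
    def __init__(self, emb_dim: int = 32, vocab_size: int = 128):
        super().__init__()
        self.body = nn.Linear(emb_dim, emb_dim)
        self.output = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)


class TestLigerFusedCrossEntropyLoss:
    @gpu_test(gpu_count=1)
    @pytest.mark.parametrize("compile", [False, True])
    def test_liger_fused_cross_entropy_loss(self, compile):
        model = TinyModel().cuda()
        fixed_init_model(model)  # deterministic init, per the suggestion above
        loss_fn = LigerFusedCrossEntropyLoss()
        loss_fn.set_model_output(model)
        if compile:
            loss_fn.apply_compile_strategy()

        x = torch.randn(2, 8, 32, device="cuda")
        targets = torch.randint(0, 128, (2, 8), device="cuda")
        hidden = model(x)

        fused_loss = loss_fn(hidden, targets)
        standard_loss = F.cross_entropy(
            model.output(hidden).flatten(0, 1), targets.flatten()
        )
        assert_expected(fused_loss, standard_loss, rtol=1e-2, atol=1e-2)

        # one backward + optimizer step to check that gradients propagate
        opt = torch.optim.SGD(model.parameters(), lr=1e-3)
        fused_loss.backward()
        opt.step()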

orig_w = self.linear_projection.weight
if isinstance(orig_w, DTensor):
mesh, placements = orig_w.device_mesh, orig_w.placements
w = orig_w.full_tensor().detach().clone().requires_grad_(True)
Contributor

Can you tell me more about what you're doing here? Does the liger loss require you to detach the weight? Why detach it only to manually register that gradients get reapplied? Also, I don't think you'd want to register a hook every forward pass.
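For context, the gather -> fuse -> scatter pattern under discussion (including the per-forward hook being questioned here) looks roughly like this hedged sketch of the loss's forward; the helper name _scatter_w is illustrative, and the import path for distribute_tensor varies across PyTorch versions:

from torch.distributed.tensor import DTensor, distribute_tensor  # older releases: torch.distributed._tensor

orig_w = self.linear_projection.weight
if isinstance(orig_w, DTensor):
    mesh, placements = orig_w.device_mesh, orig_w.placements
    # gather: materialize the full weight so the fused kernel sees a plain tensor
    w = orig_w.full_tensor().detach().clone().requires_grad_(True)

    def _scatter_w(grad):
        # scatter: re-shard the full gradient and accumulate it onto the DTensor parameter
        sharded = distribute_tensor(grad, mesh, placements)
        orig_w.grad = sharded if orig_w.grad is None else orig_w.grad + sharded

    w.register_hook(_scatter_w)
else:
    w = orig_w
# fuse: w is then passed to LigerFusedLinearCrossEntropyFunction.apply(...)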

@codecov-commenter

codecov-commenter commented May 16, 2025

Codecov Report

Attention: Patch coverage is 86.66667% with 12 lines in your changes missing coverage. Please review.

Project coverage is 62.74%. Comparing base (c8e670b) to head (b4bf352).
Report is 7 commits behind head on main.

Files with missing lines                     | Patch % | Lines
torchtune/modules/loss/cross_entropy_loss.py | 62.50%  | 12 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2741      +/-   ##
==========================================
+ Coverage   60.64%   62.74%   +2.09%     
==========================================
  Files         428      431       +3     
  Lines       26091    26479     +388     
==========================================
+ Hits        15823    16613     +790     
+ Misses      10268     9866     -402     

☔ View full report in Codecov by Sentry.

@mananchawla2005
Author

@pbontrager Hey, thanks for the in-depth review! I have tried to resolve most of the issues raised earlier; however, for the distributed test I would love some help with the code as well as with testing it, since I don't have a multi-GPU setup.

# self.forward = torch.compile(
# self.forward, *args, **kwargs
# )
return self


Do you need to compile a liger kernel at all?

Author

Hey, it was added in response to @pbontrager's comment:

I think it would be good to add @pytest.mark.parametrize("compile", [False, True]) to the test and pass in compile as an argument on whether to call apply_compile_strategy on the loss


Oh, ok, now I see 🙂

So the problem is that the kernel is automatically compiled.
This was clear from the README of their repo: there is no mention that a user needs to call torch.compile manually, which leaves only one option. Yes, Liger provides custom optimized Triton kernels, but without compilation they won't work.

So, after digging a bit into their codebase, here is how it works:

  1. LigerFusedLinearCrossEntropyFunction has custom forward and backward implementations. (Let's focus on the forward variant for now.)
  2. Inside it, fused_linear_cross_entropy_forward is called ...
  3. ... which calls a Triton kernel that has a triton.jit wrapper.

Of course, you can control even that with

with torch._dynamo.disable():

and a flag that is disabled by default and enabled in apply_compile_strategy, but since Triton kernel code is not valid Python for direct execution on a GPU without compilation, there is not much sense in it.
Perhaps one could control compilation of the service code around this kernel inside the loss, but I believe it wouldn't make much difference 🤔.

In other words, since Liger was never intended to be used without compilation, perhaps just skip this method without any warnings?
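In that case the method could be a documented no-op, for example (a hedged sketch of the change described in the next reply):

def apply_compile_strategy(self, *args, **kwargs):
    """No-op: Liger's Triton kernels are JIT-compiled on first use, so torch.compile is not applied here."""
    return self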

Author

I have removed the warning and added a docstring noting that this is JIT-compiled.

try:
import liger_kernel.ops.fused_linear_cross_entropy

self.fused_linear_ce = liger_kernel.ops.fused_linear_cross_entropy


A naive question: why did you decide to go this route instead of the loss class described in the README: https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#3-compose-your-own-model

Author

Hey, you are right to point that out! I chose the lower-level ops approach for bias handling, since the class under transformers doesn't have an option for a bias parameter, and during distributed training we need to handle DTensor. If that's not required I can replace it with the one in the README.


But it looks like the forward method of the class also accepts bias, which is None by default.

Basically, there is nothing wrong with your approach; the class-based one just looks, at least to me, slightly cleaner.

Contributor

Since the class version was just a thin wrapper around LigerFusedLinearCrossEntropyFunction, and our loss class is also a thin wrapper around the same functionality, it feels more right to me that we just directly call LigerFusedLinearCrossEntropyFunction and operate at the same abstraction level as the nn.Module you linked to.


OK, fair enough.
The linear loss calls F.cross_entropy and this one calls a function, so there is a bit of uniformity.
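For comparison, a hedged sketch of the two entry points being discussed; the argument order of the class wrapper is taken from the Liger README linked above and should be checked against the installed liger-kernel version:

import torch
from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

hidden_states = torch.randn(16, 64, device="cuda", requires_grad=True)  # [N, emb_dim]
weight = torch.randn(128, 64, device="cuda", requires_grad=True)        # [vocab_size, emb_dim]
targets = torch.randint(0, 128, (16,), device="cuda")                   # [N]

# (a) the nn.Module wrapper from the README ("compose your own model")
liger_ce = LigerFusedLinearCrossEntropyLoss()
loss_a = liger_ce(weight, hidden_states, targets)

# (b) the lower-level autograd Function this PR calls directly; it returns (loss, z_loss),
# mirroring the existing `loss, _ = ...apply(...)` call in the diff
loss_b, _ = LigerFusedLinearCrossEntropyFunction.apply(hidden_states, weight, targets)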

@joecummings mentioned this pull request Mar 30, 2025
@intervitens
Contributor

I think it makes more sense to reshape hidden_states and targets to [batch_size*seq_len] inside the loss forward and to set reduction to mean instead of sum in order to match the behavior of existing LinearCrossEntropyLoss and allow the new LigerLinearCrossEntropy to be used as a drop-in replacement in recipe configs.

I've tested a distributed LoRA finetune of llama3.1-8B with those changes, and it seems to work fine: the amount of reserved memory was reduced and the difference in loss was minimal.
https://wandb.ai/intervitens/8B-lora-dist-liger
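A minimal hedged sketch of that drop-in behaviour inside the loss's forward; attribute names such as linear_projection and ignore_index are assumed from the diff snippets elsewhere in this thread:

from liger_kernel.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

def forward(self, hidden_states: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # flatten [bsz, seq_len, emb_dim] -> [bsz*seq_len, emb_dim] inside the loss,
    # matching the existing LinearCrossEntropyLoss interface
    hidden_states = hidden_states.flatten(0, 1)
    targets = targets.flatten()
    loss, _ = LigerFusedLinearCrossEntropyFunction.apply(
        hidden_states,
        self.linear_projection.weight,
        targets,
        None,                # bias
        None,                # ce_weight
        self.ignore_index,   # ignore_index
        0.0,                 # lse_square_scale
        0.0,                 # label_smoothing
        "mean",              # reduction="mean" rather than "sum", as suggested above
    )
    return loss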

@mananchawla2005
Author

@intervitens Hey, thanks for checking out the distributed training! Glad that it works well. I have incorporated your changes. It would also be very helpful if you could provide an initial starting point for adding the distributed test.

Comment on lines 232 to 236
batch_size, seq_len, emb_dim = hidden_states.shape
hidden_states = hidden_states.reshape(
-1, emb_dim
) # [batch_size*seq_len, emb_dim]
targets = targets.reshape(-1) # [batch_size*seq_len]


Feels like, since you don't reuse the B, T, C values anywhere, it could be done more simply:

Suggested change:
-    batch_size, seq_len, emb_dim = hidden_states.shape
-    hidden_states = hidden_states.reshape(
-        -1, emb_dim
-    )  # [batch_size*seq_len, emb_dim]
-    targets = targets.reshape(-1)  # [batch_size*seq_len]
+    hidden_states = hidden_states.flatten(0, 1)  # (batch_size*seq_len, hidden_size)
+    targets = targets.flatten()  # (batch_size*seq_len)

None, # softcap
False, # return_z_loss
)
if total_elements == 0:
return loss
return loss


So, basically, return loss regardless? 🙂

@pbontrager (Contributor) left a comment


Thank you for doing all the updates! I think it's close now. I'm going to help with the unit tests and then once you have a chance to make any changes based on Andrei's comments, we should be good to land.

pyproject.toml Outdated
@@ -61,6 +62,11 @@ dev = [
"urllib3<2.0.0",
"wandb",
"expecttest",
# Triton:
"triton>=2.3.1 ; platform_system != 'Windows'",
Contributor

Is this dependency explicitly needed? PyTorch already includes Triton, I believe.

b.register_hook(_scatter_b)
self._b_hook_registered = True

loss, _ = self.fused_linear_ce.LigerFusedLinearCrossEntropyFunction.apply(
Contributor

nit: couldn't you set self.fused_linear_ce = LigerFusedLinearCrossEntropyFunction in your __init__, and then here you'd just have self.fused_linear_ce.apply(...)?


One more nit: could we simplify the arguments list, since most of the values that are provided are actually equal to the default ones?

class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ce_weight=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
        reduction="mean",
        softcap=None,
        return_z_loss: bool = False,
    ):
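Taken together, a hedged sketch of both nits, relying on the defaults shown in the signature above and passing only the values that differ (w, b, and self.ignore_index are the names already used in the diff):

# in __init__:
self.fused_linear_ce = LigerFusedLinearCrossEntropyFunction

# in forward, after gathering w (and optionally b):
loss, _ = self.fused_linear_ce.apply(
    hidden_states,
    w,
    targets,
    b,                   # bias, may simply be None
    None,                # ce_weight (default)
    self.ignore_index,   # the only scalar we set explicitly
)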

if isinstance(w, DTensor):
mesh, placements = w.device_mesh, w.placements
w = w.full_tensor()
if not hasattr(self, "_w_hook_registered"):
Contributor

I think "full_tensor" handles gradient placement and we don't need to do this ourselves link. I can test removing this though.


from torchtune.training.seed import set_seed


@gpu_test(gpu_count=1)
Contributor

Let me edit and push some changes to these tests. I can run them myself then for you.

@mananchawla2005
Author

@pbontrager I have simplified the distributed handling and done the refactoring as per the suggested changes. I would be grateful if you could help me add the tests for distributed.

@mananchawla2005
Author

Hey any updates?

@pbontrager
Contributor

@mananchawla2005 I've gotten the unit tests working to test FSDP + TP for a single training step with the Liger loss. I haven't pushed the changes to your PR yet because the DTensor backward hook doesn't seem to work correctly so I'm trying to fix that and get the test to pass. I should be able to get back to this and get something to you by the end of this week.

@pbontrager
Contributor

I pushed the tests with some changes here but the tests aren't passing in the distributed case. I've tried playing around with getting the full tensor in a full TP setting and full FSDP setting (the default test is a mix of both) but I'm still getting numerical differences. @ebsmothers do you have any ideas here?

Labels
CLA Signed: this label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Successfully merging this pull request may close these issues.

[Feature Request] Add Liger CE Loss
6 participants