Add feature ligerceloss #2741
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2741

Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @mananchawla2005! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
CLA filled!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thank you for this contribution! Really clean first pass at the problem. I left some comments around testing and questions about how you're handling DTensors. For linting, you can check out our contributing guide on how to set up pre-commit hooks.
```python
# Validate the results are close enough
assert_expected(fused_loss, standard_loss, rtol=1e-2, atol=1e-2)

def test_liger_fused_cross_entropy_loss_with_reshape(self):
```
For our SFTLoss type we can assume the input is "[bsz, seq_len, emb_dim]", so I don't think we need this second test.
I think we should have a second test, but it should be a distributed test: same as the first test but with 4 GPUs required, FSDP size 2, and TP size 2. If you need help on how to initialize the model that way, I can give you the code.
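For reference, a minimal sketch of how such a 2D FSDP + TP setup could be initialized (the `dp`/`tp` dimension names and the wiring comments are illustrative, not from this PR):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh

# 4 GPUs arranged as a 2x2 mesh: 2-way data parallel (FSDP) x 2-way tensor parallel
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]

# Typical wiring: apply parallelize_module(model, tp_mesh, tp_plan) for TP first,
# then fully_shard(model, mesh=dp_mesh) for FSDP2 over the other mesh dimension.
```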
```python
class TestLigerFusedCrossEntropyLoss:
    def test_liger_fused_cross_entropy_loss(self):
```
Since this test requires CUDA, you should add the `@gpu_test(gpu_count=1)` decorator (`from tests.test_utils import gpu_test`). Along with testing the loss value, I think it would be good to test a single forward and backward pass with an optimizer step to ensure all the gradients propagate back correctly too. You can use `fixed_init_model` (also from test_utils) to make it easier to initialize the model the same way each time.
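A rough sketch of what that forward/backward/optimizer-step check could look like (the loss class's constructor and the `linear_projection` wiring are assumptions inferred from snippets elsewhere in this thread):

```python
import torch
from tests.test_utils import fixed_init_model, gpu_test

@gpu_test(gpu_count=1)
def test_liger_fused_ce_grad_flow(self):
    proj = torch.nn.Linear(16, 32, bias=False).cuda()
    fixed_init_model(proj)  # deterministic weights for reproducible runs
    loss_fn = LigerFusedCrossEntropyLoss()  # constructor assumed
    loss_fn.linear_projection = proj        # wiring assumed from the PR snippets
    opt = torch.optim.AdamW(proj.parameters(), lr=1e-3)

    hidden = torch.randn(2, 8, 16, device="cuda", requires_grad=True)
    targets = torch.randint(0, 32, (2, 8), device="cuda")
    loss = loss_fn(hidden, targets)
    loss.backward()
    assert proj.weight.grad is not None and hidden.grad is not None
    opt.step()  # one optimizer step to confirm end-to-end trainability
```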
I think it would be good to add `@pytest.mark.parametrize("compile", [False, True])` to the test and pass in `compile` as an argument controlling whether to call `apply_compile_strategy` on the loss.
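A sketch of how that parametrization might look (`apply_compile_strategy` returning `self` is assumed from the snippet later in this thread):

```python
import pytest
from tests.test_utils import gpu_test

@pytest.mark.parametrize("compile", [False, True])
@gpu_test(gpu_count=1)
def test_liger_fused_cross_entropy_loss(self, compile):
    loss_fn = LigerFusedCrossEntropyLoss()
    if compile:
        loss_fn = loss_fn.apply_compile_strategy()
    ...  # run the same numerical check as before
```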
```python
orig_w = self.linear_projection.weight
if isinstance(orig_w, DTensor):
    mesh, placements = orig_w.device_mesh, orig_w.placements
    w = orig_w.full_tensor().detach().clone().requires_grad_(True)
```
Can you tell me more about what you're doing here? Does the Liger loss require you to detach the weight? Why detach it only to manually register that gradients get reapplied? Also, I don't think you'd want to register a hook on every forward pass.
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2741      +/-   ##
==========================================
+ Coverage   60.64%   62.74%   +2.09%
==========================================
  Files         428      431       +3
  Lines       26091    26479     +388
==========================================
+ Hits        15823    16613     +790
+ Misses      10268     9866     -402
```

View full report in Codecov by Sentry.
@pbontrager Hey, thanks for the in-depth review! I have tried to resolve most of the issues raised earlier; however, for the distributed test I would love some help with the code as well as with testing it, because I don't have a distributed GPU setup.
```python
# self.forward = torch.compile(
#     self.forward, *args, **kwargs
# )
return self
```
Do you need to compile a Liger kernel at all?
Hey, it was added in response to @pbontrager's suggestion:

> I think it would be good to add `@pytest.mark.parametrize("compile", [False, True])` to the test and pass in compile as an argument on whether to call `apply_compile_strategy` on the loss
Oh, ok, now I see 🙂

So the problem is that the kernel is automatically compiled. It was obvious from the README of their repo: there is no mention that a user needs to call torch.compile manually, which leaves only one option. Yes, Liger provides custom optimized Triton kernels, but without compilation they won't work.

So, after digging a bit into their codebase, here is how it works:

- `LigerFusedLinearCrossEntropyFunction` has custom forward and backward implementations. (Let's focus on the forward variant for now.)
- Inside it, `fused_linear_cross_entropy_forward` is called ...
- ... which calls a Triton kernel wrapped with `triton.jit`.

Of course, you could control even that with `with torch._dynamo.disable():` and a flag that is disabled by default and enabled in `apply_compile_strategy`, but since Triton kernel code is not valid Python for direct execution on a GPU without compilation, there is not much sense in it. Perhaps one could control compilation of the service code around this kernel inside the loss, but I believe it won't make much difference 🤔.

In other words, since Liger was never intended to be used without compilation, perhaps just skip this method without any warnings?
I have removed the warning and added a docstring noting that the kernel is JIT-compiled.
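Based on the commented-out snippet above, a minimal sketch of what the resulting no-op could look like (docstring wording is illustrative):

```python
def apply_compile_strategy(self, *args, **kwargs):
    """No-op: the underlying Liger Triton kernel is already JIT-compiled,
    so there is nothing for torch.compile to add here."""
    return self
```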
```python
try:
    import liger_kernel.ops.fused_linear_cross_entropy

    self.fused_linear_ce = liger_kernel.ops.fused_linear_cross_entropy
```
A dumb question: why did you decide to go this route instead of the loss class described in the README: https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#3-compose-your-own-model
Hey, you are right to point that out! I chose the lower-level ops approach for bias handling, since the class under `transformers` doesn't have an option for a bias parameter, and because during distributed training we need to handle DTensor. If that's not required, I can replace it with the one in the README.
But it looks like the forward method of the class also accepts bias, which by default is None. Basically, there is nothing wrong with your approach; the class-based one just looks slightly cleaner, at least to me.
Since the class version was just a thin wrapper around `LigerFusedLinearCrossEntropyFunction`, and our loss class is also a thin wrapper around the same functionality, it feels more right to me to just call `LigerFusedLinearCrossEntropyFunction` directly and operate at the same abstraction level as the `nn.Module` you linked to.
Ok, fair enough. The linear loss calls `F.cross_entropy` and this one calls a function, so there's a little bit of uniformity.
I've tested a distributed LoRA finetune of llama3.1-8B with those changes, and it seems to work fine; the amount of reserved memory was reduced and the difference in loss was minimal.
@intervitens Hey, thanks for checking out the distributed training! Glad that it works well. I have incorporated your changes. It would also be very helpful if you could provide an initial starting point for adding the distributed test.
```python
batch_size, seq_len, emb_dim = hidden_states.shape
hidden_states = hidden_states.reshape(
    -1, emb_dim
)  # [batch_size*seq_len, emb_dim]
targets = targets.reshape(-1)  # [batch_size*seq_len]
```
Feels like since you don't reuse the B, T, C values anywhere, it could be done more simply:

```diff
-batch_size, seq_len, emb_dim = hidden_states.shape
-hidden_states = hidden_states.reshape(
-    -1, emb_dim
-)  # [batch_size*seq_len, emb_dim]
-targets = targets.reshape(-1)  # [batch_size*seq_len]
+hidden_states = hidden_states.flatten(0, 1)  # (batch_size*seq_len, hidden_size)
+targets = targets.flatten()  # (batch_size*seq_len)
```
```python
    None,   # softcap
    False,  # return_z_loss
)
if total_elements == 0:
    return loss
return loss
```
So, basically, `return loss` regardless? 🙂
Thank you for doing all the updates! I think it's close now. I'm going to help with the unit tests and then once you have a chance to make any changes based on Andrei's comments, we should be good to land.
pyproject.toml (Outdated)

```diff
@@ -61,6 +62,11 @@ dev = [
     "urllib3<2.0.0",
     "wandb",
     "expecttest",
+    # Triton:
+    "triton>=2.3.1 ; platform_system != 'Windows'",
```
Is this dependency explicitly needed? PyTorch already includes Triton, I believe.
```python
b.register_hook(_scatter_b)
self._b_hook_registered = True

loss, _ = self.fused_linear_ce.LigerFusedLinearCrossEntropyFunction.apply(
```
nit: couldn't you set `self.fused_linear_ce = LigerFusedLinearCrossEntropyFunction` in your init, and then here just call `self.fused_linear_ce.apply(...)`?
One more nit: could we simplify the argument list, since most of the values provided are actually equal to the defaults?

```python
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
    @staticmethod
    @amp_custom_fwd
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ce_weight=None,
        ignore_index=-100,
        lse_square_scale=0.0,
        label_smoothing=0.0,
        reduction="mean",
        softcap=None,
        return_z_loss: bool = False,
    ):
```
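Given that signature, a trimmed call might look like this sketch (positional arguments only, since `autograd.Function.apply` is safest without keywords; whether `.apply` still returns a `(loss, z_loss)` tuple when `return_z_loss` is left at its default should be verified against the installed liger-kernel version):

```python
loss, _ = LigerFusedLinearCrossEntropyFunction.apply(
    hidden_states,      # [batch_size*seq_len, emb_dim]
    weight,             # [vocab_size, emb_dim]
    targets,            # [batch_size*seq_len]
    bias,               # may be None
    None,               # ce_weight (default)
    self.ignore_index,  # trailing defaults (label_smoothing, ...) can be omitted
)
```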
```python
if isinstance(w, DTensor):
    mesh, placements = w.device_mesh, w.placements
    w = w.full_tensor()
    if not hasattr(self, "_w_hook_registered"):
```
I think "full_tensor" handles gradient placement and we don't need to do this ourselves link. I can test removing this though.
```python
from torchtune.training.seed import set_seed


@gpu_test(gpu_count=1)
```
Let me edit and push some changes to these tests; then I can run them myself for you.
@pbontrager I have simplified the distributed handling and done the refactoring per the suggested changes. Would be grateful if you could help me add the distributed tests.
Hey, any updates?
@mananchawla2005 I've gotten the unit tests working to test FSDP + TP for a single training step with the Liger loss. I haven't pushed the changes to your PR yet because the DTensor backward hook doesn't seem to work correctly, so I'm trying to fix that and get the test to pass. I should be able to get back to this and get something to you by the end of this week.
I pushed the tests with some changes here, but the tests aren't passing in the distributed case. I've tried playing around with getting the full tensor in a full TP setting and a full FSDP setting (the default test is a mix of both), but I'm still getting numerical differences. @ebsmothers do you have any ideas here?
PR: Add LigerFusedCrossEntropyLoss

Context
What is the purpose of this PR? Closes #2692.

Changelog
- Added a `LigerFusedCrossEntropyLoss` class that provides memory-efficient cross-entropy loss using fused CUDA kernels from `liger-kernel`.

Test plan
- Verified numerical correctness against `F.cross_entropy`.

UX
Example usage is included in the docstring. The implementation provides better performance and memory efficiency compared to the chunked `LinearCrossEntropyLoss`. All tests verify numerical correctness against PyTorch's native cross entropy within expected tolerances.
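For illustration, a purely hypothetical usage sketch (the import path, constructor, and `linear_projection` wiring are assumptions, mirroring snippets from the review thread above):

```python
import torch
# Import path and constructor are assumptions, not confirmed by this PR.
from torchtune.modules.loss import LigerFusedCrossEntropyLoss

proj = torch.nn.Linear(64, 128, bias=False).cuda()  # stand-in output projection
loss_fn = LigerFusedCrossEntropyLoss()
loss_fn.linear_projection = proj  # wiring assumed from the PR snippets

hidden_states = torch.randn(2, 16, 64, device="cuda")    # [bsz, seq_len, emb_dim]
targets = torch.randint(0, 128, (2, 16), device="cuda")  # [bsz, seq_len]
loss = loss_fn(hidden_states, targets)  # fused projection + cross entropy
```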
Hi PyTorch Team,
This is my first PR to a machine learning project, and I’ve tried to ensure the code is correct and well-structured. I’ve implemented the functionality for LigerCeLoss and included a test case that verifies its behavior with both masked and reshaped inputs.
Due to hardware limitations, I wasn't able to fully validate the distributed functionality or run the entire test suite across multiple GPUs. However, I’ve implemented support for DTensor-based weights, hidden states, and targets, and included the logic for gather → fuse → scatter to handle gradient computation for sharded weights in distributed settings.
Looking forward to your feedback! @pbontrager @joecummings