🐛 Describe the bug
After upgrading from version 0.5.5 to 0.5.6, SFT of llama3.1-8b with trl using DataCollatorForCompletionOnlyLM and padding_free=True fails with the following error:
[rank0]: File "xxx/python3.11/site-packages/transformers/trainer.py", line 2245, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/transformers/trainer.py", line 2556, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/transformers/trainer.py", line 3718, in training_step
[rank0]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/trl/trainer/sft_trainer.py", line 495, in compute_loss
[rank0]: (loss, outputs) = super().compute_loss(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/transformers/trainer.py", line 3783, in compute_loss
[rank0]: outputs = model(**inputs)
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/deepspeed/runtime/engine.py", line 2030, in forward
[rank0]: loss = self.module(*inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]: return inner()
[rank0]: ^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/liger_kernel/transformers/model/llama.py", line 216, in lce_forward
[rank0]: loss = LigerForCausalLMLoss(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/liger_kernel/transformers/model/loss_utils.py", line 49, in LigerForCausalLMLoss
[rank0]: loss = fixed_fused_linear_cross_entropy(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "xxx/python3.11/site-packages/liger_kernel/transformers/model/loss_utils.py", line 15, in fixed_fused_linear_cross_entropy
[rank0]: loss = F.liger_fused_linear_cross_entropy(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: TypeError: liger_fused_linear_cross_entropy() got an unexpected keyword argument 'cu_seq_lens_q'
With version 0.5.5, SFT of llama3.1-8b works fine.
With version 0.5.6, SFT of gemma3-4b works fine.
The same SFT code is used in all cases.
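For context, below is a rough, hypothetical sketch of the kind of setup described above. It is not the reporter's actual script: the model path, dataset, response template, and SFTConfig values are placeholders, and some argument names (e.g. processing_class vs. tokenizer) differ between trl versions.

```python
# Hypothetical sketch only -- not the reporter's script. Model path, dataset,
# response template, and config values are placeholders.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
from liger_kernel.transformers import apply_liger_kernel_to_llama

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder

# Patch the Llama modeling code with Liger kernels; this is what routes the loss
# through liger_kernel/transformers/model/llama.py::lce_forward in the traceback.
apply_liger_kernel_to_llama()

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Tiny in-memory dataset with a "text" column, just to keep the sketch self-contained.
dataset = Dataset.from_dict({
    "text": [
        "<|start_header_id|>user<|end_header_id|>\n\nHi<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\nHello!<|eot_id|>"
    ] * 8
})

# Completion-only collator with padding-free batching: sequences are flattened
# instead of padded, which appears to be the code path where the cu_seq_lens_*
# kwargs from the traceback come into play.
collator = DataCollatorForCompletionOnlyLM(
    response_template="<|start_header_id|>assistant<|end_header_id|>\n\n",
    tokenizer=tokenizer,
    padding_free=True,
)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="llama31-sft", max_steps=10),
    train_dataset=dataset,
    data_collator=collator,
    processing_class=tokenizer,  # older trl versions use tokenizer= instead
)
trainer.train()  # with liger-kernel 0.5.6 this is where the TypeError above is raised
```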
Reproduce
No response
Versions
Environment Report:
Operating System: Linux-5.15.0-73-generic-x86_64-with-glibc2.35
Python version: 3.11.9
Liger Kernel version: 0.5.6
PyTorch version: 2.6.0+cu124
CUDA version: 12.4
HIP(ROCm) version: Not available
Triton version: 3.2.0
Transformers version: 4.50.3
XPU version: XPU Not Available