
Add FusedLinearCrossEntropy to Gemma #111

Merged 3 commits on Aug 27, 2024

Conversation

@Luke-Chesley (Contributor) commented Aug 26, 2024

Summary

This PR adds FusedLinearCrossEntropy support for Gemma to resolve issue #101.

Details

I based the code in this PR on #93, which does the same thing for Mistral.

Since the parameters fused_linear_cross_entropy and cross_entropy cannot both be true, fused_linear_cross_entropy has to be explicitly set to False in test_mini_models.py to avoid failing the convergence test with both set to True.
In test_mini_models_no_logits.py, I added two extra tests for Gemma to ensure that the FusedLinearCrossEntropy path in Gemma is tested.
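Conceptually, FusedLinearCrossEntropy computes the same loss as projecting hidden states through the lm_head and then applying cross entropy, but without materializing the full logits tensor. A toy pure-Python sketch of that equivalence (a streaming log-sum-exp stand-in, not the actual Liger Triton kernel):

```python
import math

def cross_entropy(logits, target):
    # numerically stable CE: logsumexp(logits) - logits[target]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[target]

def linear_then_ce(hidden, weight, target):
    # unfused reference: materialize logits = weight @ hidden, then CE
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in weight]
    return cross_entropy(logits, target)

def fused_linear_ce(hidden, weight, target):
    # "fused" sketch: accumulate the loss row by row over the vocab
    # without ever holding the full logits list in memory
    target_logit = sum(w * h for w, h in zip(weight[target], hidden))
    m, s = float("-inf"), 0.0  # running max and scaled sum for logsumexp
    for row in weight:
        logit = sum(w * h for w, h in zip(row, hidden))
        if logit > m:
            s = s * math.exp(m - logit) + 1.0
            m = logit
        else:
            s += math.exp(logit - m)
    return m + math.log(s) - target_logit
```

The real kernel does this tiled over the [batch*seq, vocab] logits, which is why the lm_head projection and the loss have to be patched together rather than independently.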

Testing Done

  • Hardware Type: NVIDIA A100-SXM4-80GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
Fixing ~/Liger-Kernel/test/convergence/test_mini_models_no_logits.py
Skipped 1 files
reformatted /home/luke/Desktop/Liger-Kernel/src/liger_kernel/transformers/model/gemma.py
reformatted /home/luke/Desktop/Liger-Kernel/src/liger_kernel/transformers/monkey_patch.py
reformatted /home/luke/Desktop/Liger-Kernel/test/convergence/test_mini_models_no_logits.py

All done! ✨ 🍰 ✨
3 files reformatted, 51 files left unchanged.
make: *** [Makefile:14: checkstyle] Error 1
luke@luke:~/Desktop/Liger-Kernel$ 


root@C.12106734:~/Liger-Kernel$ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
======================================================= test session starts =======================================================
platform linux -- Python 3.11.9, pytest-8.3.2, pluggy-1.5.0
rootdir: /root/Liger-Kernel
plugins: hypothesis-6.108.4, anyio-4.4.0
collected 130 items                                                                                                               

test/transformers/test_cross_entropy.py ..........................................................                          [ 44%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                 [ 49%]
test/transformers/test_geglu.py ........                                                                                    [ 55%]
test/transformers/test_rms_norm.py ................................                                                         [ 80%]
test/transformers/test_rope.py ............                                                                                 [ 89%]
test/transformers/test_swiglu.py ........                                                                                   [ 95%]
test/transformers/test_trainer_integration.py ...                                                                           [ 97%]
test/transformers/test_transformers_monkey_patch.py .                                                                       [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                  [100%]

====================================================== 130 passed in 44.12s =======================================================



root@C.12106734:~/Liger-Kernel$ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
======================================================= test session starts =======================================================
platform linux -- Python 3.11.9, pytest-8.3.2, pluggy-1.5.0
rootdir: /root/Liger-Kernel
plugins: hypothesis-6.108.4, anyio-4.4.0
collected 14 items                                                                                                                

test/convergence/test_mini_models.py ........                                                                               [ 57%]
test/convergence/test_mini_models_no_logits.py ......                                                                       [100%]

================================================= 14 passed in 125.51s (0:02:05) ==================================================


	modified:   src/liger_kernel/transformers/monkey_patch.py
	modified:   test/convergence/test_mini_models.py
	modified:   test/convergence/test_mini_models_no_logits.py
rms_norm: bool = True,
geglu: bool = True,
swiglu: bool = True,
A Collaborator commented on the snippet above:

nit: unused arg

	modified:   src/liger_kernel/transformers/monkey_patch.py
	modified:   test/convergence/test_mini_models_no_logits.py
@Luke-Chesley (Contributor, Author) commented Aug 27, 2024

Changed how the kwargs are passed in test_mini_models_no_logits.py to match how they are passed in test_mini_models.py, fixing the unused arg. All tests still pass.

@@ -143,6 +145,8 @@ def apply_liger_kernel_to_gemma(
        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma.GemmaMLP = LigerGEGLUMLP
    if fused_linear_cross_entropy:
A Collaborator commented:
could you help to add an assertion here to make sure fused_linear_cross_entropy and cross_entropy are not True together, just like in llama monkey patch function? (looks like we dropped this in qwen2 too, would be helpful if you could add it for qwen too!)

	modified:   src/liger_kernel/transformers/monkey_patch.py
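The guard the reviewer is asking for might look like the following hypothetical sketch (the real apply_liger_kernel_to_gemma has more flags, and the defaults shown here are assumptions, not the library's actual signature):

```python
def apply_liger_kernel_to_gemma(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
) -> None:
    # Mirrors the assertion in the llama monkey-patch function: the two
    # loss paths are mutually exclusive, so fail fast instead of
    # silently patching both.
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )
    # ... actual patching logic (modeling_gemma.CrossEntropyLoss, etc.)
```

Failing at patch time makes the misconfiguration obvious, rather than surfacing later as a confusing convergence-test failure.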
@yundai424 (Collaborator) left a review:

lgtm, thanks!

@qingquansong (Collaborator) left a review:

LGTM!

@ByronHsu ByronHsu merged commit a49e421 into linkedin:main Aug 27, 2024
@DocShotgun commented Aug 27, 2024

Just had a quick question for clarification - is this for the original Gemma, or Gemma2 - or are the changes functional for both? I noticed the Liger Kernel README refers to the API for patching Gemma 2 as liger_kernel.transformers.apply_liger_kernel_to_gemma.

EDIT: I was informed that this code doesn't seem to include the softcapping used in Gemma2's forward:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054-L1057
So perhaps it's only fully functional for the original Gemma?
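For reference, the softcapping in the linked Gemma2 modeling code divides the logits by a cap, applies tanh, and multiplies the cap back. A scalar sketch in plain Python (the cap value of 30.0 is an assumption here; Gemma2 reads it from config as final_logit_softcapping):

```python
import math

def soft_cap(logit: float, cap: float = 30.0) -> float:
    # Gemma2-style final-logit softcapping: squash each logit into
    # (-cap, cap) with tanh while preserving ordering.
    return cap * math.tanh(logit / cap)
```

Because this transform sits between the lm_head projection and the loss, a fused kernel that never materializes logits would have to apply it inside the fusion, which supports the conclusion that this PR as merged only fully covers the original Gemma.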
