Add FusedLinearCrossEntropy to Gemma #111
modified: src/liger_kernel/transformers/monkey_patch.py modified: test/convergence/test_mini_models.py modified: test/convergence/test_mini_models_no_logits.py
    rms_norm: bool = True,
    geglu: bool = True,
    swiglu: bool = True,
nit: unused arg
modified: src/liger_kernel/transformers/monkey_patch.py modified: test/convergence/test_mini_models_no_logits.py
Changed how the kwargs are passed in test_mini_models_no_logits.py to match how they are passed in test_mini_models.py, which fixes the unused arg. All tests still pass.
@@ -143,6 +145,8 @@ def apply_liger_kernel_to_gemma(
        modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
    if geglu:
        modeling_gemma.GemmaMLP = LigerGEGLUMLP
    if fused_linear_cross_entropy:
Could you add an assertion here to make sure fused_linear_cross_entropy and cross_entropy are not True together, just like in the llama monkey patch function? (Looks like we dropped this in qwen2 too; it would be helpful if you could add it for qwen as well!)
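For reference, the requested guard could look roughly like the sketch below. The helper name `check_loss_flags` is hypothetical and exists only to keep the example self-contained; in the actual patch the assertion would presumably sit at the top of `apply_liger_kernel_to_gemma`, and the exact message wording is an assumption.

```python
def check_loss_flags(cross_entropy: bool, fused_linear_cross_entropy: bool) -> None:
    """Hypothetical helper: reject enabling both loss patches at once."""
    # The two options replace the same loss computation, so at most one may be on.
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )


# Either flag alone (or neither) passes the check.
check_loss_flags(cross_entropy=True, fused_linear_cross_entropy=False)
check_loss_flags(cross_entropy=False, fused_linear_cross_entropy=True)
```

Failing at patch time gives a clear error message rather than a confusing convergence failure later.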
modified: src/liger_kernel/transformers/monkey_patch.py
lgtm, thanks!
LGTM!
Just had a quick question for clarification: is this for the original Gemma or Gemma2, or are the changes functional for both? I noticed the Liger Kernel README refers to the API for patching Gemma 2 as
EDIT: I was informed that this code doesn't seem to include the softcapping used in Gemma2's forward:
Summary
This PR adds FusedLinearCrossEntropy support for Gemma to resolve issue #101.
Details
I based the code in this PR off of #93 which does the same thing for Mistral.
Since the parameters fused_linear_cross_entropy and cross_entropy cannot both be True, fused_linear_cross_entropy has to be explicitly set to False in test_mini_models.py; otherwise the convergence test fails because both are True.
In test_mini_models_no_logits.py, I added two extra tests for Gemma to ensure that the FusedLinearCrossEntropy path in Gemma is tested.
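The flag handling described above can be sketched as follows. `apply_patch_stub` is a hypothetical stand-in for the Gemma patch function, and its default values are assumptions rather than the library's confirmed signature; the point is only to show why the logits-based test has to opt out of the fused path explicitly.

```python
def apply_patch_stub(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
) -> str:
    """Hypothetical stand-in for the patch function; defaults are assumed."""
    assert not (cross_entropy and fused_linear_cross_entropy), (
        "cross_entropy and fused_linear_cross_entropy cannot both be True."
    )
    if fused_linear_cross_entropy:
        return "fused_linear_cross_entropy"
    if cross_entropy:
        return "cross_entropy"
    return "default"


# Logits-based convergence test: turn the fused path off explicitly.
with_logits_kwargs = {"cross_entropy": True, "fused_linear_cross_entropy": False}
# No-logits convergence test: exercise the fused path instead.
no_logits_kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": True}

selected_with_logits = apply_patch_stub(**with_logits_kwargs)
selected_no_logits = apply_patch_stub(**no_logits_kwargs)
```

Under these assumed defaults, passing only `cross_entropy=True` would trip the assertion, which is the failure mode the explicit `False` avoids.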
Testing Done
make test to ensure correctness
make checkstyle to ensure code style
make test-convergence to ensure convergence