
Monkeypatch for Phi3 #76

Merged: 13 commits into linkedin:main on Aug 27, 2024

Conversation

@tyler-romero tyler-romero commented Aug 24, 2024

Summary

Add a new monkeypatch function to support patching Huggingface's Phi3 implementation with Liger Kernels.

Phi3 has its own MLP implementation (Phi3MLP), so a LigerPhi3SwiGLUMLP implementation that leverages LigerSiLUMulFunction is provided as well.
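
To make the shape of the change concrete, here is a rough sketch of the monkeypatch pattern Liger-Kernel uses (illustrative only, not the merged code verbatim): module-level attributes of HF's `modeling_phi3` are swapped for Liger equivalents before the model is instantiated.

```python
# Sketch of the monkeypatch pattern (illustrative; not the exact merged implementation).
from transformers.models.phi3 import modeling_phi3

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP


def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
) -> None:
    """Swap HF Phi3 components for Liger kernels. Call before instantiating the model."""
    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
```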

Testing Done

  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Convergence test added (and passing on my 4090) for a mini model based on Phi3 patched with Liger kernels.
All tests passing.
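
For context on what "mini model" means here: the convergence tests build a deliberately tiny Phi3Config so the Liger-patched model can be trained for a few steps and compared against the unpatched HF implementation. A hypothetical configuration along these lines (values are illustrative, not copied from the test):

```python
# Hypothetical mini-Phi3 config for a convergence test; values are illustrative,
# not the exact ones used in test/convergence/test_mini_models.py.
from transformers import Phi3Config  # requires transformers >= 4.41

mini_phi3_config = Phi3Config(
    vocab_size=32064,
    hidden_size=896,
    intermediate_size=4864,
    num_hidden_layers=4,
    num_attention_heads=8,
    num_key_value_heads=8,
    max_position_embeddings=4096,
    rms_norm_eps=1e-5,
)
```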

Questions for Discussion

Apparently Phi3 was only added in transformers v4.41, but the lowest supported version of transformers in Liger-Kernel is 4.40.1. Additionally, only more recently has sdpa been supported in HF Phi3. Thoughts? Should I leave the transformers dependency version as-is?

@ByronHsu ByronHsu left a comment


@@ -153,6 +153,8 @@ loss.backward()
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
| Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
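
The README row above implies the usual usage pattern: call the patch function before the model is created, then load Phi3 as normal. A minimal sketch (the checkpoint name is just an example):

```python
# Minimal usage sketch; the checkpoint name is only an example.
import torch
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3()  # patch HF Phi3 with Liger kernels before loading

model = transformers.AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
)
```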
Contributor Author

Sorry for the whitespace modifications here; this is the only change in this file.

Collaborator

thanks a lot for helping on the housekeeping! <3

Collaborator

Do you want to implement FusedLinearCrossEntropy in the next PR and also create an issue to track it? Essentially, add a method patch.

Contributor Author

Yeah happy to do that as well

Collaborator

#98
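
For readers unfamiliar with the "method patch" mentioned above: for FusedLinearCrossEntropy the patch replaces the causal-LM `forward` so the full logits tensor is never materialized; the hidden states and `lm_head` weight are handed to a fused linear + cross-entropy kernel instead. A heavily simplified sketch of the idea (not the merged code; the fused kernel call is replaced here with a plain matmul + `cross_entropy` stand-in, and the return value is reduced to just the loss):

```python
# Simplified sketch of a "method patch" for FusedLinearCrossEntropy (illustrative only).
import torch.nn.functional as F
from transformers.models.phi3 import modeling_phi3


def phi3_lce_forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
    hidden_states = outputs[0]
    loss = None
    if labels is not None:
        shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, self.config.hidden_size)
        shift_labels = labels[..., 1:].contiguous().view(-1)
        # Stand-in: the real patch would call a fused linear + cross-entropy kernel here
        # instead of materializing logits via an explicit matmul.
        logits = shift_hidden @ self.lm_head.weight.t()
        loss = F.cross_entropy(logits.float(), shift_labels)
    return loss  # the real forward would return a CausalLMOutputWithPast


# The actual patch would then be applied as:
# modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
```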

setup.py (review thread: outdated, resolved)
@@ -38,3 +38,27 @@ def __init__(self, config):
def forward(self, x):

return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


class LigerPhi3SwiGLUMLP(nn.Module):
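
For context, HF's Phi3MLP differs from the Llama-style MLP above: it fuses the gate and up projections into a single `gate_up_proj` linear layer, so the Liger variant has to split that projection's output before applying the fused SiLU-mul kernel. A rough sketch of what the class looks like (not guaranteed to match the merged code exactly):

```python
# Sketch of a Phi3-style SwiGLU MLP built on LigerSiLUMulFunction; mirrors Phi3MLP's
# fused gate_up_proj + down_proj layout, but may differ from the merged implementation.
import torch.nn as nn

from liger_kernel.ops.swiglu import LigerSiLUMulFunction


class LigerPhi3SwiGLUMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Phi3 fuses the gate and up projections into one linear layer.
        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        up_states = self.gate_up_proj(x)
        gate, up_states = up_states.chunk(2, dim=-1)
        # Fused SiLU(gate) * up, then project back to hidden size.
        return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
```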
Contributor Author

So Llama has its own MLP implementation here, but it's named very generally. I went for a model-specific name here, but open to suggestions.

Collaborator

yeah. this is necessary

@tyler-romero tyler-romero marked this pull request as ready for review August 24, 2024 22:37
@yundai424 yundai424 left a comment


Thanks for this hyper speed PR woo hoo

test/transformers/test_swiglu.py (review thread: outdated, resolved)
Makefile (review thread: resolved)
ByronHsu commented Aug 26, 2024

Apparently Phi3 was only added in transformers v4.41, but the lowest supported version of transformers in Liger-Kernel is 4.40.1. Additionally, only more recently has sdpa been supported in HF Phi3 (huggingface/transformers#32457). Thoughts? Should I leave the transformers dependency version as-is?

Yeah, let's bump to 4.41. sdpa doesn't seem to be a hard requirement since most users use flash attn.
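
For reference, the bump is a one-line change to the transformers pin in setup.py; a sketch of what the dependency list could look like afterwards (the other pins are placeholders, not the actual setup.py contents):

```python
# Illustrative only; the surrounding pins are placeholders, not the real setup.py.
install_requires = [
    "torch>=2.1.2",
    "triton>=2.3.0",
    "transformers>=4.41.0",  # bumped from 4.40.1 so HF Phi3 (added in 4.41) is available
]
```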

@ByronHsu
We can merge after the transformers bump and fixing the conflict.

@tyler-romero

make test && make test-convergence
python -m pytest --disable-warnings test/ --ignore=test/convergence
=================================================================================================== test session starts ====================================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 138 items                                                                                                                                                                                                        

test/transformers/test_cross_entropy.py ........................................................ss                                                                                                                   [ 42%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                                                                                                          [ 46%]
test/transformers/test_geglu.py ........                                                                                                                                                                             [ 52%]
test/transformers/test_rms_norm.py ................................                                                                                                                                                  [ 75%]
test/transformers/test_rope.py ............                                                                                                                                                                          [ 84%]
test/transformers/test_swiglu.py ................                                                                                                                                                                    [ 95%]
test/transformers/test_trainer_integration.py ...                                                                                                                                                                    [ 97%]
test/transformers/test_transformers_monkey_patch.py .                                                                                                                                                                [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                                           [100%]

======================================================================================== 136 passed, 2 skipped in 77.79s (0:01:17) =========================================================================================
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
=================================================================================================== test session starts ====================================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 12 items                                                                                                                                                                                                         

test/convergence/test_mini_models.py F.........                                                                                                                                                                      [ 83%]
test/convergence/test_mini_models_no_logits.py ..                                                                                                                                                                    [100%]

========================================================================================================= FAILURES =========================================================================================================
_____________________________________________________________________ test_mini_model[mini_gemma-32-0.0001-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] _____________________________________________________________________

All relevant tests are passing; the mini_gemma convergence test is now failing on my 4090 though, perhaps related to the recent PR (#85):

>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 3
E           Mismatch at index (0, 29): tensor1[(0, 29)] = 0.3853762447834015, tensor2[(0, 29)] = 0.38537222146987915
E           Mismatch at index (0, 30): tensor1[(0, 30)] = 0.7438472509384155, tensor2[(0, 30)] = 0.7438388466835022
E           Mismatch at index (0, 31): tensor1[(0, 31)] = 0.5316240191459656, tensor2[(0, 31)] = 0.5316107869148254

@tyler-romero

I think I'll need a review from @yundai424 as well to bypass the "changes requested" flag

@lancerts

@yundai424 can you take a look? ty


@ByronHsu ByronHsu left a comment


lgtm and we can fix gemma in #92

@yundai424 yundai424 merged commit ee2dacb into linkedin:main Aug 27, 2024
1 check passed