fix liger plugin load issues #1876

Merged · 1 commit · Aug 28, 2024
fix liger plugin load issues
tmm1 committed Aug 27, 2024
commit 56d58ffe622d1e85f35c85029c67b007a8d9c2ad
8 changes: 5 additions & 3 deletions src/axolotl/integrations/liger/__init__.py
@@ -22,8 +22,7 @@

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
-from liger_kernel.transformers.model.llama import lce_forward
-from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
+from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
 from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
@@ -54,7 +53,7 @@ def pre_model_load(self, cfg):
             if cfg.liger_cross_entropy:
                 modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
             elif cfg.liger_fused_linear_cross_entropy:
-                modeling_llama.LlamaForCausalLM.forward = lce_forward
+                modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

         elif cfg.model_config_type == "mistral":
             from transformers.models.mistral import modeling_mistral
@@ -105,6 +104,9 @@ def pre_model_load(self, cfg):
                 modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward

         elif cfg.model_config_type == "qwen2":
+            from liger_kernel.transformers.model.qwen2 import (
+                lce_forward as qwen2_lce_forward,
+            )
             from transformers.models.qwen2 import modeling_qwen2

             if cfg.liger_rope:
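The load issue fixed here comes from importing qwen2's lce_forward at module import time: if the installed liger_kernel predates qwen2 support, the whole plugin fails to import even when the user only trains a llama model. The patch defers that import into the qwen2 branch of pre_model_load and aliases the llama forward as llama_lce_forward so the names cannot collide. Below is a minimal sketch of that deferred-import pattern, not the plugin itself: the function name and the SimpleNamespace config are hypothetical, and the Qwen2ForCausalLM patch line is assumed to mirror the llama branch rather than taken from this diff.

# Sketch of the deferred-import pattern adopted in this PR (illustrative only).
from types import SimpleNamespace


def pre_model_load_sketch(cfg):
    if cfg.model_config_type == "llama":
        # In the real plugin this import lives at module level; it is aliased
        # (llama_lce_forward) so it cannot clash with other models' lce_forward.
        from liger_kernel.transformers.model.llama import (
            lce_forward as llama_lce_forward,
        )
        from transformers.models.llama import modeling_llama

        if cfg.liger_fused_linear_cross_entropy:
            modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

    elif cfg.model_config_type == "qwen2":
        # Deferred import: only evaluated when a qwen2 model is configured, so
        # a liger_kernel build without qwen2 support no longer breaks plugin
        # load for every other model type.
        from liger_kernel.transformers.model.qwen2 import (
            lce_forward as qwen2_lce_forward,
        )
        from transformers.models.qwen2 import modeling_qwen2

        if cfg.liger_fused_linear_cross_entropy:
            # Assumed to mirror the llama branch; not shown in this diff.
            modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward


# Example: a llama-only run never touches the qwen2 kernel module.
pre_model_load_sketch(
    SimpleNamespace(model_config_type="llama", liger_fused_linear_cross_entropy=True)
)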