Skip to content

Commit

Permalink
move import to their appropriate conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Aug 28, 2024
1 parent e82f1a0 commit 89e0c2c
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/axolotl/integrations/liger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
Expand All @@ -45,6 +42,9 @@ def get_input_args(self):

def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama

if cfg.liger_rope:
Expand All @@ -59,6 +59,9 @@ def pre_model_load(self, cfg):
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward

elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral

if cfg.liger_rope:
Expand All @@ -73,6 +76,9 @@ def pre_model_load(self, cfg):
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward

elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma

if cfg.liger_rope:
Expand Down

0 comments on commit 89e0c2c

Please sign in to comment.