Skip to content

Commit

Permalink
Integrate Phi3 LCE support
Browse files Browse the repository at this point in the history
  • Loading branch information
DocShotgun authored Aug 28, 2024
1 parent 89e0c2c commit 5bc5d51
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/axolotl/integrations/liger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def pre_model_load(self, cfg):
)

elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3

if cfg.liger_rope:
Expand All @@ -178,6 +181,4 @@ def pre_model_load(self, cfg):
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Phi 3."
)
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

0 comments on commit 5bc5d51

Please sign in to comment.