Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
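Fix model name checking: make the "mistral" / "mixtral" / "gemma" substring checks case-insensitive by lowercasing model_name_or_path before matching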
  • Loading branch information
lightmatmul authored Apr 16, 2024
1 parent 367df45 commit 1b57991
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions minigemini/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,15 +1043,15 @@ def train(attn_implementation=None):
))

if model_args.vision_tower is not None:
- if "mistral" in model_args.model_name_or_path:
+ if "mistral" in model_args.model_name_or_path.lower():
model = MiniGeminiMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
**bnb_model_from_pretrained_args
)
- elif "mixtral" in model_args.model_name_or_path:
+ elif "mixtral" in model_args.model_name_or_path.lower():
model = MiniGeminiMixtralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
Expand All @@ -1061,7 +1061,7 @@ def train(attn_implementation=None):
)
from deepspeed.utils import set_z3_leaf_modules
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
- elif "gemma" in model_args.model_name_or_path:
+ elif "gemma" in model_args.model_name_or_path.lower():
model = MiniGeminiGemmaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
Expand Down

0 comments on commit 1b57991

Please sign in to comment.