From 85612a6cc0eb674382aefc6ef6d71ff32f9d0e67 Mon Sep 17 00:00:00 2001 From: Mehant Kammakomati Date: Thu, 29 Aug 2024 14:10:11 +0530 Subject: [PATCH] fix: remove fire ported from Hari's PR #303 Signed-off-by: Mehant Kammakomati Signed-off-by: Harikrishnan Balagopal --- pyproject.toml | 1 - tuning/sft_trainer.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e31192470..2675f49b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,6 @@ dependencies = [ "trl>=0.9.3,<1.0", "peft>=0.8.0,<0.13", "datasets>=2.15.0,<3.0", -"fire>=0.5.0,<1.0", "simpleeval>=0.9.13,<1.0", ] diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index b5e6cb62e..36690657f 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -37,7 +37,6 @@ ) from transformers.utils import is_accelerate_available from trl import SFTConfig, SFTTrainer -import fire import transformers # Local @@ -508,7 +507,7 @@ def parse_arguments(parser, json_config=None): ) -def main(**kwargs): # pylint: disable=unused-argument +def main(): parser = get_parser() logger = logging.getLogger() job_config = get_json_config() @@ -629,4 +628,4 @@ def main(**kwargs): # pylint: disable=unused-argument if __name__ == "__main__": - fire.Fire(main) + main()