diff --git a/train.py b/train.py index de6aad7..54a868a 100644 --- a/train.py +++ b/train.py @@ -547,18 +547,9 @@ def main(config): ) accelerator.end_training() -# if __name__ == "__main__": -# parser = argparse.ArgumentParser() -# parser.add_argument("--config", type=str, default='/remote-home/lzwang/projects/MotionInversion/configs/config.yaml') -# args = parser.parse_args() - -# # Load and merge configurations -# config = OmegaConf.load(args.config) -# main(config) - if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default='/remote-home/lzwang/projects/MotionInversion/configs/config.yaml') + parser.add_argument("--config", type=str, default='./configs/config.yaml') parser.add_argument("--single_video_path", type=str) parser.add_argument("--prompts", type=str, help="JSON string of prompts") args = parser.parse_args() @@ -575,6 +566,4 @@ def main(config): if args.prompts: config.val.prompt = json.loads(args.prompts) - - main(config)