From b893316e41bfef06b2b35159efa68570ea7fcec3 Mon Sep 17 00:00:00 2001 From: chenxuanhong Date: Thu, 21 Apr 2022 18:23:47 +0800 Subject: [PATCH] fix input problem --- docs/guidance/preparation.md | 2 +- train.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/guidance/preparation.md b/docs/guidance/preparation.md index 525f619b..56dca989 100644 --- a/docs/guidance/preparation.md +++ b/docs/guidance/preparation.md @@ -16,7 +16,7 @@ pip install insightface==0.2.1 onnxruntime moviepy - We use the face parsing from **[face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)** for image postprocessing. Please download the relative file and place it in ./parsing_model/checkpoint from [this link](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view). - The pytorch and cuda versions above are most recommanded. They may vary. - Using insightface with different versions is not recommanded. Please use this specific version. -- These settings are tested valid on both Windows and Ununtu. +- These settings are tested valid on both Windows and Ubuntu. ### Pretrained model There are two archive files in the drive: **checkpoints.zip** and **arcface_checkpoint.tar** diff --git a/train.py b/train.py index 9e77695e..93df8685 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # Created Date: Monday December 27th 2021 # Author: Chen Xuanhong # Email: chenxuanhongzju@outlook.com -# Last Modified: Wednesday, 20th April 2022 6:33:30 pm +# Last Modified: Thursday, 21st April 2022 6:21:17 pm # Modified By: Chen Xuanhong # Copyright (c) 2021 Shanghai Jiao Tong University ############################################################# @@ -29,6 +29,8 @@ from models.projected_model import fsModel from data.data_loader_Swapping import GetLoader +def str2bool(v): + return v.lower() in ('true') class TrainOptions: def __init__(self): @@ -39,7 +41,10 @@ def initialize(self): self.parser.add_argument('--name', type=str, default='simswap', help='name of the experiment. It decides where to store samples and models') self.parser.add_argument('--gpu_ids', default='0') self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') - self.parser.add_argument('--isTrain', type=bool, default=True) + self.parser.add_argument('--isTrain', type=str2bool, default='True') + + # parser.add_argument('--use_tensorboard', type=str2bool, default='True', + # choices=['True', 'False'], help='enable the tensorboard') # input/output sizes self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size') @@ -57,8 +62,8 @@ def initialize(self): self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero') self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam') self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam') - self.parser.add_argument("--Gdeep",type=bool,default=False) - self.parser.add_argument("--train_simswap",type=bool,default=True) + self.parser.add_argument('--Gdeep', type=str2bool, default='False') + self.parser.add_argument('--train_simswap', type=str2bool, default='True') # for discriminators self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')