Skip to content

Commit

Permalink
fix input problem
Browse files Browse the repository at this point in the history
  • Loading branch information
neuralchen committed Apr 21, 2022
1 parent 1bbc1ef commit b893316
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/guidance/preparation.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pip install insightface==0.2.1 onnxruntime moviepy
- We use the face parsing from **[face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)** for image postprocessing. Please download the relative file and place it in ./parsing_model/checkpoint from [this link](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view).
- The pytorch and cuda versions above are most recommanded. They may vary.
- Using insightface with different versions is not recommanded. Please use this specific version.
- These settings are tested valid on both Windows and Ununtu.
- These settings are tested valid on both Windows and Ubuntu.

### Pretrained model
There are two archive files in the drive: **checkpoints.zip** and **arcface_checkpoint.tar**
Expand Down
13 changes: 9 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 20th April 2022 6:33:30 pm
# Last Modified: Thursday, 21st April 2022 6:21:17 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
Expand All @@ -29,6 +29,8 @@
from models.projected_model import fsModel
from data.data_loader_Swapping import GetLoader

def str2bool(v):
return v.lower() in ('true')

class TrainOptions:
def __init__(self):
Expand All @@ -39,7 +41,10 @@ def initialize(self):
self.parser.add_argument('--name', type=str, default='simswap', help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--gpu_ids', default='0')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--isTrain', type=bool, default=True)
self.parser.add_argument('--isTrain', type=str2bool, default='True')

# parser.add_argument('--use_tensorboard', type=str2bool, default='True',
# choices=['True', 'False'], help='enable the tensorboard')

# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
Expand All @@ -57,8 +62,8 @@ def initialize(self):
self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.0, help='momentum term of adam')
self.parser.add_argument('--lr', type=float, default=0.0004, help='initial learning rate for adam')
self.parser.add_argument("--Gdeep",type=bool,default=False)
self.parser.add_argument("--train_simswap",type=bool,default=True)
self.parser.add_argument('--Gdeep', type=str2bool, default='False')
self.parser.add_argument('--train_simswap', type=str2bool, default='True')

# for discriminators
self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
Expand Down

0 comments on commit b893316

Please sign in to comment.