From 6a2e03d798e1a676ee11b167ce9c109074f51bb1 Mon Sep 17 00:00:00 2001
From: chenxuanhong
Date: Fri, 22 Apr 2022 00:15:53 +0800
Subject: [PATCH] Update train.py

---
 train.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index 636f0112..03dfb58f 100644
--- a/train.py
+++ b/train.py
@@ -5,20 +5,18 @@
 # Created Date: Monday December 27th 2021
 # Author: Chen Xuanhong
 # Email: chenxuanhongzju@outlook.com
-# Last Modified: Thursday, 21st April 2022 10:59:07 pm
+# Last Modified: Friday, 22nd April 2022 12:15:47 am
 # Modified By: Chen Xuanhong
 # Copyright (c) 2021 Shanghai Jiao Tong University
 #############################################################

 import os
 import time
-import wandb
 import random
 import argparse
 import numpy as np

 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.backends import cudnn
 import torch.utils.tensorboard as tensorboard
@@ -52,9 +50,10 @@ def initialize(self):

         # for training
         self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2", help='path to the face swapping dataset')
-        self.parser.add_argument('--continue_train', type=bool, default=False, help='continue training: load the latest model')
+        self.parser.add_argument('--continue_train', type=str2bool, default='True', help='continue training: load the latest model')
+        # self.parser.add_argument('--Gdeep', type=str2bool, default='False')
         self.parser.add_argument('--load_pretrain', type=str, default='checkpoints', help='load the pretrained model from the specified location')
-        self.parser.add_argument('--which_epoch', type=str, default='800000', help='which epoch to load? set to latest to use latest cached model')
+        self.parser.add_argument('--which_epoch', type=str, default='320', help='which epoch to load? set to latest to use latest cached model')
         self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
         self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
         self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero')