Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
neuralchen committed Apr 21, 2022
1 parent 7ed12d2 commit 9f3daca
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ If you find this project useful, please star it. It is the greatest appreciation
Download the dataset from [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).

The training script is slightly different from the original version, e.g., we replace the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results.
In order to ensure normal training, the batch size must be greater than 1.

- Train 256 models
```
Expand Down
7 changes: 0 additions & 7 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,6 @@ def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.state_dict(), save_path)

# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids=None):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda()

# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label, save_dir=''):
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 8:10:05 pm
# Last Modified: Thursday, 21st April 2022 10:36:48 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
Expand Down Expand Up @@ -44,7 +44,7 @@ def initialize(self):
self.parser.add_argument('--isTrain', type=str2bool, default='True')

# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')
self.parser.add_argument('--batchSize', type=int, default=2, help='input batch size')

# for displays
self.parser.add_argument('--tag', type=str, default='simswap')
Expand All @@ -69,9 +69,9 @@ def initialize(self):

self.parser.add_argument("--Arc_path", type=str, default='arcface_model/arcface_checkpoint.tar', help="run ONNX model via TRT")
self.parser.add_argument("--total_step", type=int, default=1000000, help='total training step')
self.parser.add_argument("--log_frep", type=int, default=250, help='frequence for printing log information')
self.parser.add_argument("--sample_freq", type=int, default=1000, help='frequence for sampling')
self.parser.add_argument("--model_freq", type=int, default=10000, help='frequence for saving the model')
self.parser.add_argument("--log_frep", type=int, default=10, help='frequence for printing log information')
self.parser.add_argument("--sample_freq", type=int, default=30, help='frequence for sampling')
self.parser.add_argument("--model_freq", type=int, default=40, help='frequence for saving the model')



Expand Down

0 comments on commit 9f3daca

Please sign in to comment.