Skip to content

Commit

Permalink
fix save model bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
neuralchen committed Apr 21, 2022
1 parent b893316 commit 7ed12d2
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
torch.save(network.state_dict(), save_path)

# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
def save_network(self, network, network_label, epoch_label, gpu_ids=None):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
Expand Down
5 changes: 4 additions & 1 deletion models/fs_networks_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
padding_type='reflect'):
assert (n_blocks >= 0)
super(Generator_Adain_Upsample, self).__init__()

activation = nn.ReLU(True)

self.deep = deep

self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
norm_layer(64), activation)
### downsample
Expand All @@ -101,6 +103,7 @@ def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
norm_layer(256), activation)
self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)

if self.deep:
self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
norm_layer(512), activation)
Expand Down
4 changes: 2 additions & 2 deletions models/projected_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Created Date: Wednesday January 12th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Wednesday, 20th April 2022 6:34:47 pm
# Last Modified: Thursday, 21st April 2022 8:13:37 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
Expand Down Expand Up @@ -94,7 +94,7 @@ def cosin_metric(self, x1, x2):
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch)
self.save_network(self.netD, 'D', which_epoch)
self.save_optim(self.optimizer_G, 'G', which_epoch,)
self.save_optim(self.optimizer_G, 'G', which_epoch)
self.save_optim(self.optimizer_D, 'D', which_epoch)
'''if self.gen_features:
self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)'''
Expand Down
5 changes: 1 addition & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Thursday, 21st April 2022 6:21:17 pm
# Last Modified: Thursday, 21st April 2022 8:10:05 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
Expand Down Expand Up @@ -43,9 +43,6 @@ def initialize(self):
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--isTrain', type=str2bool, default='True')

# parser.add_argument('--use_tensorboard', type=str2bool, default='True',
# choices=['True', 'False'], help='enable the tensorboard')

# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=16, help='input batch size')

Expand Down

0 comments on commit 7ed12d2

Please sign in to comment.