
Commit

more reasonable pipeline and training on multigpu and testing on one gpu bug
Prinsphield committed Sep 18, 2018
1 parent d78608e commit 2a09a3b
Showing 3 changed files with 66 additions and 51 deletions.
113 changes: 63 additions & 50 deletions ELEGANT.py
@@ -38,9 +38,38 @@ def __init__(self, args,
self.adv_criterion = torch.nn.BCELoss()
self.recon_criterion = torch.nn.MSELoss()

self.set_mode_and_gpu()
self.restore_from_file()
self.set_mode_and_gpu()


def restore_from_file(self):
if self.restore is not None:
ckpt_file_enc = os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_enc)
ckpt_file_dec = os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_dec)
if self.gpu:
self.Enc.load_state_dict(torch.load(ckpt_file_enc), strict=False)
self.Dec.load_state_dict(torch.load(ckpt_file_dec), strict=False)
else:
self.Enc.load_state_dict(torch.load(ckpt_file_enc, map_location='cpu'), strict=False)
self.Dec.load_state_dict(torch.load(ckpt_file_dec, map_location='cpu'), strict=False)

if self.mode == 'train':
ckpt_file_d1 = os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_d1)
ckpt_file_d2 = os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_d2)
if self.gpu:
self.D1.load_state_dict(torch.load(ckpt_file_d1), strict=False)
self.D2.load_state_dict(torch.load(ckpt_file_d2), strict=False)
else:
self.D1.load_state_dict(torch.load(ckpt_file_d1, map_location='cpu'), strict=False)
self.D2.load_state_dict(torch.load(ckpt_file_d2, map_location='cpu'), strict=False)

self.start_step = self.restore + 1
else:
self.start_step = 1
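The restore logic above comes down to one pattern: map checkpoints onto the CPU when no GPU is available and load them non-strictly. A minimal standalone sketch of that pattern (the helper name, path, and model are placeholders, not repo code):

```python
import os
import torch
import torch.nn as nn

def load_checkpoint(model: nn.Module, ckpt_path: str, use_gpu: bool) -> None:
    # Placeholder helper illustrating the loading pattern used in restore_from_file.
    assert os.path.exists(ckpt_path), ckpt_path
    if use_gpu:
        state = torch.load(ckpt_path)                      # tensors keep their saved device
    else:
        state = torch.load(ckpt_path, map_location='cpu')  # remap CUDA tensors onto the CPU
    # strict=False tolerates missing or unexpected keys, which helps when a
    # checkpoint was written from a DataParallel-wrapped model.
    model.load_state_dict(state, strict=False)
```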

def set_mode_and_gpu(self):
if self.mode == 'train':
@@ -49,6 +78,23 @@ def set_mode_and_gpu(self):
self.D1.train()
self.D2.train()

self.writer = SummaryWriter(self.config.log_dir)

self.optimizer_G = torch.optim.Adam(chain(self.Enc.parameters(), self.Dec.parameters()),
lr=self.config.G_lr, betas=(0.5, 0.999),
weight_decay=self.config.weight_decay)

self.optimizer_D = torch.optim.Adam(chain(self.D1.parameters(), self.D2.parameters()),
lr=self.config.D_lr, betas=(0.5, 0.999),
weight_decay=self.config.weight_decay)

self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.step_size, gamma=self.config.gamma)
self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.step_size, gamma=self.config.gamma)
if self.restore is not None:
    for _ in range(self.restore):
        self.G_lr_scheduler.step()
        self.D_lr_scheduler.step()
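Replaying the schedulers `restore` times makes a resumed run start from the already-decayed learning rate instead of the initial one. A toy illustration of the effect (the numbers are made up, and the exact lr bookkeeping differs slightly across PyTorch versions):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.1)

restore_step = 2500                     # pretend training is resumed at iteration 2500
for _ in range(restore_step):           # replay the schedule up to that point
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # ~1e-5: the lr has been decayed twice by gamma
```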

if self.gpu:
with torch.cuda.device(0):
self.Enc.cuda()
@@ -80,29 +126,6 @@ def set_mode_and_gpu(self):
else:
raise NotImplementedError()
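The elided middle of set_mode_and_gpu is where the multi-GPU handling mentioned in the commit title lives. As a generic sketch of that pattern only (not the repo's exact code), a network is usually moved to the first device and wrapped in nn.DataParallel when several GPU ids are requested:

```python
import torch.nn as nn

def to_devices(model: nn.Module, gpu_ids):
    # Generic multi-GPU setup sketch; gpu_ids is e.g. [] (CPU), [0], or [0, 1].
    if not gpu_ids:
        return model                     # CPU-only: leave the model untouched
    model = model.cuda(gpu_ids[0])       # parameters live on the first GPU
    if len(gpu_ids) > 1:
        model = nn.DataParallel(model, device_ids=gpu_ids)  # replicate per forward pass
    return model
```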

def restore_from_file(self):
if self.restore is not None:
ckpt_file_enc = os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_enc)
self.Enc.load_state_dict(torch.load(ckpt_file_enc))

ckpt_file_dec = os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_dec)
self.Dec.load_state_dict(torch.load(ckpt_file_dec))

if self.mode == 'train':
ckpt_file_d1 = os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_d1)
self.D1.load_state_dict(torch.load(ckpt_file_d1))

ckpt_file_d2 = os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.restore))
assert os.path.exists(ckpt_file_d2)
self.D2.load_state_dict(torch.load(ckpt_file_d2))

self.start_step = self.restore + 1
else:
self.start_step = 1

def tensor2var(self, tensors, volatile=False):
if not hasattr(tensors, '__iter__'): tensors = [tensors]
out = []
@@ -123,7 +146,6 @@ def get_attr_chs(self, encodings, attribute_id):
end = int(np.rint(per_chs * (attribute_id + 1)))
return encodings[:, start:end]
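get_attr_chs carves the latent channels into one contiguous block per attribute. A small numeric sketch of that partition (the channel and attribute counts are made up for illustration):

```python
import numpy as np

def attr_channel_range(n_channels, n_attributes, attribute_id):
    # Evenly split n_channels into n_attributes contiguous blocks (illustrative helper).
    per_chs = float(n_channels) / n_attributes
    start = int(np.rint(per_chs * attribute_id))
    end = int(np.rint(per_chs * (attribute_id + 1)))
    return start, end

print([attr_channel_range(512, 3, i) for i in range(3)])
# -> [(0, 171), (171, 341), (341, 512)]
```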


def forward_G(self):
self.z_A, self.A_skip = self.Enc(self.A, return_skip=True)
self.z_B, self.B_skip = self.Enc(self.B, return_skip=True)
@@ -241,35 +263,25 @@ def save_scalar_log(self):
self.writer.add_scalar(tag, value, self.step)
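save_scalar_log pushes each tracked loss into TensorBoard through SummaryWriter.add_scalar. A minimal sketch of that pattern (the log directory and scalar values are made up; tensorboardX is assumed here, though torch.utils.tensorboard exposes the same interface):

```python
from tensorboardX import SummaryWriter

writer = SummaryWriter('./train_log')            # illustrative log directory
scalars = {'loss_G': 0.73, 'loss_D': 0.41}       # stand-ins for the real loss dict
for tag, value in scalars.items():
    writer.add_scalar(tag, value, global_step=100)
writer.close()
```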

def save_model(self):
torch.save({key: val.cpu() for key, val in self.Enc.state_dict().items()}, os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.step)))
torch.save({key: val.cpu() for key, val in self.Dec.state_dict().items()}, os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.step)))
torch.save({key: val.cpu() for key, val in self.D1.state_dict().items()}, os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.step)))
torch.save({key: val.cpu() for key, val in self.D2.state_dict().items()}, os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.step)))
def reduced(key):
if key.startswith('module.'):
return key[7:]
else:
return key
torch.save({reduced(key): val.cpu() for key, val in self.Enc.state_dict().items()}, os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.step)))
torch.save({reduced(key): val.cpu() for key, val in self.Dec.state_dict().items()}, os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.step)))
torch.save({reduced(key): val.cpu() for key, val in self.D1.state_dict().items()}, os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.step)))
torch.save({reduced(key): val.cpu() for key, val in self.D2.state_dict().items()}, os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.step)))
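The new reduced() helper exists because nn.DataParallel namespaces every parameter under 'module.', so checkpoints saved from a multi-GPU run cannot be loaded into a bare single-GPU or CPU model without stripping that prefix. A small demonstration of the mismatch and the fix (the toy Linear model is only for illustration):

```python
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)                # wrapping is what happens during multi-GPU training
print(list(net.state_dict().keys()))          # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))      # ['module.weight', 'module.bias']

# Dropping the 'module.' prefix restores keys that a bare model understands.
cleaned = {k[len('module.'):] if k.startswith('module.') else k: v.cpu()
           for k, v in wrapped.state_dict().items()}
net.load_state_dict(cleaned)                  # loads without key mismatches
```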

def train(self):
self.writer = SummaryWriter(self.config.log_dir)

self.optimizer_G = torch.optim.Adam(chain(self.Enc.parameters(), self.Dec.parameters()),
lr=self.config.G_lr, betas=(0.5, 0.999),
weight_decay=self.config.weight_decay)

self.optimizer_D = torch.optim.Adam(chain(self.D1.parameters(), self.D2.parameters()),
lr=self.config.D_lr, betas=(0.5, 0.999),
weight_decay=self.config.weight_decay)

self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.step_size, gamma=self.config.gamma)
self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.step_size, gamma=self.config.gamma)

# start training
for step in range(self.start_step, 1 + self.config.max_iter):
self.step = step
for self.step in range(self.start_step, 1 + self.config.max_iter):
if self.step > self.start_step + 3: break;
self.G_lr_scheduler.step()
self.D_lr_scheduler.step()

for attribute_id in range(self.n_attributes):
self.attribute_id = attribute_id
A, y_A = next(self.dataset.gen(attribute_id, True))
B, y_B = next(self.dataset.gen(attribute_id, False))
for self.attribute_id in range(self.n_attributes):
A, y_A = next(self.dataset.gen(self.attribute_id, True))
B, y_B = next(self.dataset.gen(self.attribute_id, False))
self.A, self.y_A, self.B, self.y_B = self.tensor2var([A, y_A, B, y_B])

# forward
@@ -299,7 +311,8 @@ def train(self):
if self.step % 100 == 0:
self.save_scalar_log()

if self.step % 2000 == 0:
# if self.step % 2000 == 0:
if 1:
self.save_model()

print('Finished Training!')
3 changes: 2 additions & 1 deletion README.md
@@ -73,8 +73,9 @@ We provide four types of mode for testing. Let me explain all the parameters for

- `-a`: All attributes' names.
- `-r`: Restore checkpoint.
- `-g`: The gpu id(s) for testing.
- `-g`: The GPU id(s) for testing.
- Don't add this parameter to your shell command if you don't want to use a GPU for testing.
- Specify at most one GPU for testing, because a single image cannot be split across multiple GPUs.
- `--swap`: Swap attribute of two images.
- `--linear`: Linear interpolation by adding or removing one certain attribute.
- `--matrix`: Matrix interpolation with respect to one or two attributes.
1 change: 1 addition & 0 deletions datasets
