diff --git a/ELEGANT.py b/ELEGANT.py
index c023961..cce45d4 100644
--- a/ELEGANT.py
+++ b/ELEGANT.py
@@ -38,9 +38,38 @@ def __init__(self, args,
         self.adv_criterion = torch.nn.BCELoss()
         self.recon_criterion = torch.nn.MSELoss()

-        self.set_mode_and_gpu()
         self.restore_from_file()
+        self.set_mode_and_gpu()
+
+
+    def restore_from_file(self):
+        if self.restore is not None:
+            ckpt_file_enc = os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.restore))
+            assert os.path.exists(ckpt_file_enc)
+            ckpt_file_dec = os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.restore))
+            assert os.path.exists(ckpt_file_dec)
+            if self.gpu:
+                self.Enc.load_state_dict(torch.load(ckpt_file_enc), strict=False)
+                self.Dec.load_state_dict(torch.load(ckpt_file_dec), strict=False)
+            else:
+                self.Enc.load_state_dict(torch.load(ckpt_file_enc, map_location='cpu'), strict=False)
+                self.Dec.load_state_dict(torch.load(ckpt_file_dec, map_location='cpu'), strict=False)
+
+            if self.mode == 'train':
+                ckpt_file_d1 = os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.restore))
+                assert os.path.exists(ckpt_file_d1)
+                ckpt_file_d2 = os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.restore))
+                assert os.path.exists(ckpt_file_d2)
+                if self.gpu:
+                    self.D1.load_state_dict(torch.load(ckpt_file_d1), strict=False)
+                    self.D2.load_state_dict(torch.load(ckpt_file_d2), strict=False)
+                else:
+                    self.D1.load_state_dict(torch.load(ckpt_file_d1, map_location='cpu'), strict=False)
+                    self.D2.load_state_dict(torch.load(ckpt_file_d2, map_location='cpu'), strict=False)
+            self.start_step = self.restore + 1
+        else:
+            self.start_step = 1

     def set_mode_and_gpu(self):
         if self.mode == 'train':
@@ -49,6 +78,23 @@ def set_mode_and_gpu(self):
             self.D1.train()
             self.D2.train()

+            self.writer = SummaryWriter(self.config.log_dir)
+
+            self.optimizer_G = torch.optim.Adam(chain(self.Enc.parameters(), self.Dec.parameters()),
+                                                lr=self.config.G_lr, betas=(0.5, 0.999),
+                                                weight_decay=self.config.weight_decay)
+
+            self.optimizer_D = torch.optim.Adam(chain(self.D1.parameters(), self.D2.parameters()),
+                                                lr=self.config.D_lr, betas=(0.5, 0.999),
+                                                weight_decay=self.config.weight_decay)
+
+            self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.step_size, gamma=self.config.gamma)
+            self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.step_size, gamma=self.config.gamma)
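+            # When resuming from a checkpoint, fast-forward both LR schedulers so the
+            # learning rate matches the restored iteration.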
+            if self.restore is not None:
+                for _ in range(self.restore):
+                    self.G_lr_scheduler.step()
+                    self.D_lr_scheduler.step()
+
             if self.gpu:
                 with torch.cuda.device(0):
                     self.Enc.cuda()
@@ -80,29 +126,6 @@ def set_mode_and_gpu(self):
         else:
             raise NotImplementationError()

-    def restore_from_file(self):
-        if self.restore is not None:
-            ckpt_file_enc = os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.restore))
-            assert os.path.exists(ckpt_file_enc)
-            self.Enc.load_state_dict(torch.load(ckpt_file_enc))
-
-            ckpt_file_dec = os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.restore))
-            assert os.path.exists(ckpt_file_dec)
-            self.Dec.load_state_dict(torch.load(ckpt_file_dec))
-
-            if self.mode == 'train':
-                ckpt_file_d1 = os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.restore))
-                assert os.path.exists(ckpt_file_d1)
-                self.D1.load_state_dict(torch.load(ckpt_file_d1))
-
-                ckpt_file_d2 = os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.restore))
-                assert os.path.exists(ckpt_file_d2)
-                self.D2.load_state_dict(torch.load(ckpt_file_d2))
-
-            self.start_step = self.restore + 1
-        else:
-            self.start_step = 1
-
     def tensor2var(self, tensors, volatile=False):
         if not hasattr(tensors, '__iter__'): tensors = [tensors]
         out = []
@@ -123,7 +146,6 @@ def get_attr_chs(self, encodings, attribute_id):
         end = int(np.rint(per_chs * (attribute_id + 1)))
         return encodings[:, start:end]

-
     def forward_G(self):
         self.z_A, self.A_skip = self.Enc(self.A, return_skip=True)
         self.z_B, self.B_skip = self.Enc(self.B, return_skip=True)
@@ -241,35 +263,25 @@ def save_scalar_log(self):
             self.writer.add_scalar(tag, value, self.step)

     def save_model(self):
-        torch.save({key: val.cpu() for key, val in self.Enc.state_dict().items()}, os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.step)))
-        torch.save({key: val.cpu() for key, val in self.Dec.state_dict().items()}, os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.step)))
-        torch.save({key: val.cpu() for key, val in self.D1.state_dict().items()}, os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.step)))
-        torch.save({key: val.cpu() for key, val in self.D2.state_dict().items()}, os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.step)))
+        def reduced(key):
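+            # nn.DataParallel prefixes parameter names with 'module.'; strip it so
+            # the saved checkpoints load whether or not the model is wrapped.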
+            if key.startswith('module.'):
+                return key[7:]
+            else:
+                return key
+        torch.save({reduced(key): val.cpu() for key, val in self.Enc.state_dict().items()}, os.path.join(self.config.model_dir, 'Enc_iter_{:06d}.pth'.format(self.step)))
+        torch.save({reduced(key): val.cpu() for key, val in self.Dec.state_dict().items()}, os.path.join(self.config.model_dir, 'Dec_iter_{:06d}.pth'.format(self.step)))
+        torch.save({reduced(key): val.cpu() for key, val in self.D1.state_dict().items()}, os.path.join(self.config.model_dir, 'D1_iter_{:06d}.pth'.format(self.step)))
+        torch.save({reduced(key): val.cpu() for key, val in self.D2.state_dict().items()}, os.path.join(self.config.model_dir, 'D2_iter_{:06d}.pth'.format(self.step)))

     def train(self):
-        self.writer = SummaryWriter(self.config.log_dir)
-
-        self.optimizer_G = torch.optim.Adam(chain(self.Enc.parameters(), self.Dec.parameters()),
-                                            lr=self.config.G_lr, betas=(0.5, 0.999),
-                                            weight_decay=self.config.weight_decay)
-
-        self.optimizer_D = torch.optim.Adam(chain(self.D1.parameters(), self.D2.parameters()),
-                                            lr=self.config.D_lr, betas=(0.5, 0.999),
-                                            weight_decay=self.config.weight_decay)
-
-        self.G_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_G, step_size=self.config.step_size, gamma=self.config.gamma)
-        self.D_lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer_D, step_size=self.config.step_size, gamma=self.config.gamma)
-
-        # start training
-        for step in range(self.start_step, 1 + self.config.max_iter):
-            self.step = step
+        for self.step in range(self.start_step, 1 + self.config.max_iter):
             self.G_lr_scheduler.step()
             self.D_lr_scheduler.step()

-            for attribute_id in range(self.n_attributes):
-                self.attribute_id = attribute_id
-                A, y_A = next(self.dataset.gen(attribute_id, True))
-                B, y_B = next(self.dataset.gen(attribute_id, False))
+            for self.attribute_id in range(self.n_attributes):
+                A, y_A = next(self.dataset.gen(self.attribute_id, True))
+                B, y_B = next(self.dataset.gen(self.attribute_id, False))
                 self.A, self.y_A, self.B, self.y_B = self.tensor2var([A, y_A, B, y_B])

             # forward
diff --git a/README.md b/README.md
index 8df0bf4..9b8cea2 100644
--- a/README.md
+++ b/README.md
@@ -73,8 +73,9 @@ We provide four types of mode for testing. Let me explain all the parameters for
 - `-a`: All attributes' names.
 - `-r`: Restore checkpoint.
-- `-g`: The gpu id(s) for testing.
+- `-g`: The GPU id(s) for testing.
   - Don't add this parameter to your shell command if you don't want to use gpu for testing.
+  - Specify at most one GPU for testing, because a single input image cannot be split across multiple GPUs.
 - `--swap`: Swap attribute of two images.
 - `--linear`: Linear interpolation by adding or removing one certain attribute.
 - `--matrix`: Matrix interpolation with respect to one or two attributes.
diff --git a/datasets b/datasets
new file mode 120000
index 0000000..c43550b
--- /dev/null
+++ b/datasets
@@ -0,0 +1 @@
+/tmp5/taihong/datasets
\ No newline at end of file