forked from anonymousECCV2022/paper2031_ECCV2022_code
Commit 7b75784
Showing 1 changed file with 161 additions and 0 deletions.
@@ -0,0 +1,161 @@
import os
import math
from decimal import Decimal

import utility

import numpy as np
import torch
import torch.nn.utils as utils
from tqdm import tqdm
import random
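
# Trainer owns the full optimization loop: the model, the loss module, the
# optimizer, and the checkpoint handler (ckp), following the EDSR-style
# code layout this repository appears to build on.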
class Trainer():
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        # Resume the optimizer state when continuing from a saved checkpoint.
        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8
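
    # One call to train() runs a single epoch: log the learning rate, iterate
    # over the training loader, and step the scheduler at the end.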
    def train(self):
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        lr = self.optimizer.get_lr()  # learning rate, used only for logging

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP: pin the training dataset to scale index 0.
        self.loader_train.dataset.set_scale(0)
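
        # Rotation-consistency step: k in {0, 1, 2, 3} picks a random multiple
        # of 90 degrees, applied over the spatial axes (dims 2 and 3 of the
        # NCHW batch). Super-resolving and then rotating should match rotating
        # and then super-resolving; the loss terms below enforce this
        # equivariance.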
        for batch, (lr, hr, _) in enumerate(self.loader_train):
            # 'lr' now holds the low-resolution batch, reusing the name of
            # the learning rate logged above.
            k = random.randint(0, 3)
            lr_r = torch.from_numpy(np.rot90(lr, -k, (2, 3)).copy())
            lr, hr, lr_r = self.prepare(lr, hr, lr_r)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, 0)
            sr_r = self.model(lr_r, 0)

            # Rotate each detached prediction into the other branch's frame;
            # detaching makes them fixed targets, so no gradient flows through
            # the rotation (note: this assumes CUDA, unlike prepare()).
            sr_g = torch.from_numpy(
                np.rot90(sr.detach().cpu().numpy(), -k, (2, 3)).copy()
            ).cuda()
            sr_r_g = torch.from_numpy(
                np.rot90(sr_r.detach().cpu().numpy(), k, (2, 3)).copy()
            ).cuda()

            # Changed loss call: reconstruction against the ground truth plus
            # both rotation-consistency terms (the earlier single-term version
            # is kept below for reference).
            #loss = self.loss(sr, hr, sr_r_g)
            loss = self.loss(sr, hr, sr_g, sr_r, sr_r_g)

            loss.backward()
            if self.args.gclip > 0:
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.optimizer.step()

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()
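
    # test() evaluates the current model on every test set at every scale:
    # outputs are quantized to the RGB range, PSNR is averaged per
    # dataset/scale pair, and the checkpoint log tracks the best epoch.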
    def test(self):
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results: self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    sr = self.model(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        # Move tensors to the target device, optionally in half precision.
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half': tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        # In test-only mode run a single evaluation and stop; otherwise tell
        # the caller to stop once the epoch budget is exhausted.
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs
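
For context, this trainer is driven from an entry script in the usual EDSR-PyTorch fashion. The sketch below is a hypothetical driver and not part of this commit: the data, model, loss, and option modules and the utility.checkpoint class are assumptions based on that layout.

# main.py -- hypothetical driver in the EDSR style; the modules imported
# below are assumed to exist elsewhere in the repository.
import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)   # assumed checkpoint handler (ckp)

if checkpoint.ok:
    loader = data.Data(args)            # provides loader_train / loader_test
    _model = model.Model(args, checkpoint)
    _loss = loss.Loss(args, checkpoint) if not args.test_only else None
    t = Trainer(args, loader, _model, _loss, checkpoint)
    # terminate() flips to True after args.epochs, or after a single
    # evaluation pass when args.test_only is set.
    while not t.terminate():
        t.train()
        t.test()
    checkpoint.done()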