diff --git a/core/utils/utils.py b/core/utils/utils.py index 741ccfe..7e90f4e 100644 --- a/core/utils/utils.py +++ b/core/utils/utils.py @@ -1,28 +1,32 @@ +import numpy as np import torch +import torch.autograd as autograd import torch.nn.functional as F -import numpy as np from scipy import interpolate class InputPadder: """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): self.ht, self.wd = dims[-2:] pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 if mode == 'sintel': - self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + self._pad = [pad_wd//2, pad_wd - pad_wd//2, + pad_ht//2, pad_ht - pad_ht//2] else: self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] def pad(self, *inputs): return [F.pad(x, self._pad, mode='replicate') for x in inputs] - def unpad(self,x): + def unpad(self, x): ht, wd = x.shape[-2:] c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] return x[..., c[0]:c[1], c[2]:c[3]] + def forward_interpolate(flow): flow = flow.detach().cpu().numpy() dx, dy = flow[0], flow[1] @@ -32,7 +36,7 @@ def forward_interpolate(flow): x1 = x0 + dx y1 = y0 + dy - + x1 = x1.reshape(-1) y1 = y1.reshape(-1) dx = dx.reshape(-1) @@ -57,7 +61,7 @@ def forward_interpolate(flow): def bilinear_sampler(img, coords, mode='bilinear', mask=False): """ Wrapper for grid_sample, uses pixel coordinates """ H, W = img.shape[-2:] - xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid, ygrid = coords.split([1, 1], dim=-1) xgrid = 2*xgrid/(W-1) - 1 ygrid = 2*ygrid/(H-1) - 1 @@ -72,11 +76,81 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): def coords_grid(batch, ht, wd, device): - coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device)) coords = torch.stack(coords[::-1], dim=0).float() return 
coords[None].repeat(batch, 1, 1, 1) def upflow8(flow, mode='bilinear'): new_size = (8 * flow.shape[2], 8 * flow.shape[3]) - return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def create_flow_grid(flow): + B, C, H, W = flow.size() + # mesh grid + xx = torch.arange(0, W).view(1, -1).repeat(H, 1) + yy = torch.arange(0, H).view(-1, 1).repeat(1, W) + xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) + yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) + grid = torch.cat((xx, yy), 1).float() + + if flow.is_cuda: + grid = grid.to(flow.get_device()) + vgrid = grid + flow + + # scale grid to [-1,1] + vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 + vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 + + return vgrid.permute(0, 2, 3, 1) + + +def warp_flow(x, flow, use_mask=False): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + Inputs: + x: [B, C, H, W] (im2) + flow: [B, 2, H, W] flow + Returns: + output: [B, C, H, W] + """ + vgrid = create_flow_grid(flow) + output = F.grid_sample(x, vgrid, align_corners=True) + if use_mask: + mask = autograd.Variable(torch.ones(x.size())).to(x.get_device()) + mask = F.grid_sample(mask, vgrid, align_corners=True) + mask[mask < 0.9999] = 0 + mask[mask > 0] = 1 + output = output * mask + + return output + + +def SSIM_error(x, y): + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + mu_x = F.avg_pool2d(x, 3, 1, 0) + mu_y = F.avg_pool2d(y, 3, 1, 0) + + # (input, kernel, stride, padding) + sigma_x = F.avg_pool2d(x ** 2, 3, 1, 0) - mu_x ** 2 + sigma_y = F.avg_pool2d(y ** 2, 3, 1, 0) - mu_y ** 2 + sigma_xy = F.avg_pool2d(x * y, 3, 1, 0) - mu_x * mu_y + + SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2) + SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2) + + SSIM = SSIM_n / SSIM_d + + return torch.clamp((1 - SSIM) / 2, 0, 1) + + +def photometric_error(img1: 
torch.Tensor, img2: torch.Tensor, + flow: torch.Tensor, valid: torch.Tensor): + img1_warped = warp_flow(img2, flow) + l1_err = (img1_warped * valid - img1 * valid).abs() + ssim_err = SSIM_error(img1_warped * valid, img1 * valid) + return l1_err.mean(), ssim_err.mean() diff --git a/train-selfsupervised.py b/train-selfsupervised.py new file mode 100644 index 0000000..625159f --- /dev/null +++ b/train-selfsupervised.py @@ -0,0 +1,225 @@ +from __future__ import division, print_function + +import sys  # nopep8 +from pathlib import Path  # nopep8 + +sys.path.append('core')  # nopep8 + +import argparse +import os +import time +from typing import List + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.backends.cudnn as cudnn_backend +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from logger import Logger +from raft import RAFT +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm +from utils.utils import photometric_error + +import datasets +import evaluate + +try: + from torch.cuda.amp import GradScaler +except: + # dummy GradScaler for PyTorch < 1.6 + class GradScaler: + def __init__(self): + pass + + def scale(self, loss): + return loss + + def unscale_(self, optimizer): + pass + + def step(self, optimizer): + optimizer.step() + + def update(self): + pass + + +# exclude extremely large displacements +MAX_FLOW = 400 +SUM_FREQ = 100 +VAL_FREQ = 5000 +SSIM_WEIGHT = 0.84 + + +def sequence_loss(flow_preds: List[torch.Tensor], flow_gt: torch.Tensor, + image1: torch.Tensor, image2: torch.Tensor, + valid: torch.Tensor, gamma=0.8, max_flow=MAX_FLOW): + """ Loss function defined over sequence of flow predictions """ + + n_predictions = len(flow_preds) + flow_loss: torch.Tensor = 0.0 + + # exclude invalid pixels and extremely large displacements + mag = torch.sum(flow_gt**2, dim=1).sqrt() + valid = (valid >= 0.5) & (mag < max_flow) + + for i in 
range(n_predictions): + i_weight = gamma**(n_predictions - i - 1) + l1_err, ssim_err = photometric_error(image1, image2, flow_preds[i], valid[:, None]) + i_loss = (1 - SSIM_WEIGHT) * l1_err + SSIM_WEIGHT * ssim_err + flow_loss += i_weight * i_loss + + epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt() + epe = epe.view(-1)[valid.view(-1)] + + metrics = { + 'epe': epe.mean().item(), + '1px': (epe < 1).float().mean().item(), + '3px': (epe < 3).float().mean().item(), + '5px': (epe < 5).float().mean().item(), + } + + return flow_loss, metrics + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def fetch_optimizer(args, model, steps): + """ Create the optimizer and learning rate scheduler """ + optimizer = optim.AdamW(model.parameters(), lr=args.lr, + weight_decay=args.wdecay, eps=args.epsilon) + + scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, steps+100, + pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') + + return optimizer, scheduler + + +def train(args): + model = nn.DataParallel(RAFT(args), device_ids=args.gpus) + print("Parameter Count: %d" % count_parameters(model)) + + if args.restore_ckpt is not None: + model.load_state_dict(torch.load(args.restore_ckpt), + strict=(not args.allow_nonstrict)) + + model.cuda() + model.train() + + if args.freeze_bn: + model.module.freeze_bn() + + train_loader = datasets.fetch_dataloader(args) + optimizer, scheduler = fetch_optimizer(args, model, + len(train_loader) * args.num_epochs) + + scaler = GradScaler(enabled=args.mixed_precision) + logger = Logger(args.name) + + VAL_FREQ = 5000 + add_noise = True + best_evaluation = None + + for epoch in range(args.num_epochs): + logger.initPbar(len(train_loader), epoch + 1) + for batch_idx, data_blob in enumerate(train_loader): + optimizer.zero_grad() + image1, image2, flow, valid = [x.cuda() for x in data_blob] + + if args.add_noise: + stdv = np.random.uniform(0.0, 5.0) + image1 = (image1 + 
stdv * torch.randn(* + image1.shape).cuda()).clamp(0.0, 255.0) + image2 = (image2 + stdv * torch.randn(* + image2.shape).cuda()).clamp(0.0, 255.0) + + flow_predictions = model(image1, image2, iters=args.iters) + + loss, metrics = sequence_loss(flow_predictions, flow, + image1, image2, valid, + args.gamma) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + + scaler.step(optimizer) + scheduler.step() + scaler.update() + + logger.push({'loss': loss.item()}) + + logger.closePbar() + PATH = 'checkpoints/%s/model.pth' % args.name + torch.save(model.state_dict(), PATH) + + results = {} + for val_dataset in args.validation: + if val_dataset == 'chairs': + results.update(evaluate.validate_chairs(model.module)) + elif val_dataset == 'sintel': + results.update(evaluate.validate_sintel(model.module)) + elif val_dataset == 'kitti': + results.update(evaluate.validate_kitti(model.module)) + logger.write_dict(results, 'epoch') + + evaluation_score = np.mean(list(results.values())) + if best_evaluation is None or evaluation_score < best_evaluation: + best_evaluation = evaluation_score + PATH = 'checkpoints/%s/model-best.pth' % args.name + torch.save(model.state_dict(), PATH) + + model.train() + if args.freeze_bn: + model.module.freeze_bn() + + logger.close() + + return best_evaluation + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--name', default='raft', help="name your experiment") + parser.add_argument('--freeze_bn', action='store_true', + help="freeze the batch norm layer") + parser.add_argument('--restore_ckpt', help="restore checkpoint") + parser.add_argument('--allow_nonstrict', action='store_true', + help='allow non-strict loading') + parser.add_argument('--small', action='store_true', help='use small model') + parser.add_argument('--validation', type=str, nargs='+') + + parser.add_argument('--lr', type=float, default=0.00002) + 
parser.add_argument('--num_epochs', type=int, default=10) + parser.add_argument('--batch_size', type=int, default=6) + parser.add_argument('--image_size', type=int, + nargs='+', default=[368, 768]) + parser.add_argument('--gpus', type=int, nargs='+', default=[0]) + parser.add_argument('--mixed_precision', + action='store_true', help='use mixed precision') + + parser.add_argument('--iters', type=int, default=12) + parser.add_argument('--wdecay', type=float, default=.00005) + parser.add_argument('--epsilon', type=float, default=1e-8) + parser.add_argument('--clip', type=float, default=1.0) + parser.add_argument('--dropout', type=float, default=0.0) + parser.add_argument('--gamma', type=float, default=0.8, + help='exponential weighting') + parser.add_argument('--add_noise', action='store_true') + args = parser.parse_args() + + torch.manual_seed(1234) + np.random.seed(1234) + + cudnn_backend.benchmark = True + + os.makedirs(Path(__file__).parent.joinpath('checkpoints', args.name), + exist_ok=True) + + train(args) diff --git a/train-selfsupervised.sh b/train-selfsupervised.sh new file mode 100644 index 0000000..4af90e6 --- /dev/null +++ b/train-selfsupervised.sh @@ -0,0 +1,27 @@ +#!/bin/bash +mkdir -p checkpoints + +cmd_scratch="python -u train-selfsupervised.py \ + --name raft-sintel-selfsupervised-scratch \ + --validation sintel \ + --num_epochs 100 \ + --batch_size 6 \ + --lr 0.0004 \ + --wdecay 0.00001" + +cmd_transfer="python -u train-selfsupervised.py \ + --name raft-sintel-selfsupervised-transfer \ + --validation sintel \ + --restore_ckpt checkpoints/raft-things.pth \ + --freeze_bn \ + --num_epochs 100 \ + --batch_size 6 \ + --lr 0.000125 \ + --wdecay 0.00001 \ + --gamma=0.85" + +# echo ${cmd_scratch} +# eval ${cmd_scratch} + +echo ${cmd_transfer} +eval ${cmd_transfer} diff --git a/train-supervised.py b/train-supervised.py index 1850083..dc636ca 100644 --- a/train-supervised.py +++ b/train-supervised.py @@ -150,6 +150,8 @@ def train(args): logger.push({'loss': 
loss.item()}) logger.closePbar() + PATH = 'checkpoints/%s/model.pth' % args.name + torch.save(model.state_dict(), PATH) results = {} for val_dataset in args.validation: @@ -172,10 +174,8 @@ def train(args): model.module.freeze_bn() logger.close() - PATH = 'checkpoints/%s/model.pth' % args.name - torch.save(model.state_dict(), PATH) - return PATH + return best_evaluation if __name__ == '__main__': @@ -193,8 +193,8 @@ def train(args): parser.add_argument('--num_epochs', type=int, default=10) parser.add_argument('--batch_size', type=int, default=6) parser.add_argument('--image_size', type=int, - nargs='+', default=[384, 512]) - parser.add_argument('--gpus', type=int, nargs='+', default=[0, 1]) + nargs='+', default=[368, 768]) + parser.add_argument('--gpus', type=int, nargs='+', default=[0]) parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') @@ -213,6 +213,7 @@ def train(args): cudnn_backend.benchmark = True - os.makedirs(Path(__file__).parent.joinpath('checkpoints', args.name)) + os.makedirs(Path(__file__).parent.joinpath('checkpoints', args.name), + exist_ok=True) train(args) diff --git a/train-supervised.sh b/train-supervised.sh index 1bc30b2..a0bbb77 100644 --- a/train-supervised.sh +++ b/train-supervised.sh @@ -4,11 +4,9 @@ mkdir -p checkpoints cmd_scratch="python -u train-supervised.py \ --name raft-sintel-supervised-scratch \ --validation sintel \ - --gpus 0 \ --num_epochs 100 \ --batch_size 6 \ --lr 0.0004 \ - --image_size 368 768 \ --wdecay 0.00001" cmd_transfer="python -u train-supervised.py \ @@ -16,11 +14,9 @@ cmd_transfer="python -u train-supervised.py \ --validation sintel \ --restore_ckpt checkpoints/raft-things.pth \ --freeze_bn \ - --gpus 0 \ --num_epochs 100 \ --batch_size 6 \ --lr 0.000125 \ - --image_size 368 768 \ --wdecay 0.00001 \ --gamma=0.85"