Self-supervised training
hm-ysjiang committed Jun 3, 2023
1 parent 122087d commit 788de8d
Showing 5 changed files with 340 additions and 17 deletions.
88 changes: 81 additions & 7 deletions core/utils/utils.py
@@ -1,28 +1,32 @@
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn.functional as F
from scipy import interpolate


class InputPadder:
""" Pads images such that dimensions are divisible by 8 """

def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
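        # e.g. a 436x1024 Sintel frame pads to 440x1024: pad_ht = (440 - 436) % 8 = 4, pad_wd = 0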
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2,
pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]

def unpad(self, x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]


def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
@@ -32,7 +36,7 @@ def forward_interpolate(flow):

x1 = x0 + dx
y1 = y0 + dy

x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
@@ -57,7 +61,7 @@ def forward_interpolate(flow):
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1

@@ -72,11 +76,81 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):


def coords_grid(batch, ht, wd, device):
coords = torch.meshgrid(torch.arange(ht, device=device),
torch.arange(wd, device=device))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)


def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)


def create_flow_grid(flow):
B, C, H, W = flow.size()
# mesh grid
xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
grid = torch.cat((xx, yy), 1).float()

if flow.is_cuda:
grid = grid.to(flow.get_device())
vgrid = grid + flow

# scale grid to [-1,1]
vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

return vgrid.permute(0, 2, 3, 1)


def warp_flow(x, flow, use_mask=False):
"""
warp an image/tensor (im2) back to im1, according to the optical flow
Inputs:
x: [B, C, H, W] (im2)
flow: [B, 2, H, W] flow
Returns:
ouptut: [B, C, H, W]
"""
vgrid = create_flow_grid(flow)
output = F.grid_sample(x, vgrid, align_corners=True)
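    # grid_sample zero-fills samples that land outside the image (its default
    # padding_mode='zeros'); the optional mask below warps an all-ones tensor
    # through the same grid to find and zero out those locations explicitly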
if use_mask:
mask = autograd.Variable(torch.ones(x.size())).to(x.get_device())
mask = F.grid_sample(mask, vgrid, align_corners=True)
mask[mask < 0.9999] = 0
mask[mask > 0] = 1
output = output * mask

return output


def SSIM_error(x, y):
C1 = 0.01 ** 2
C2 = 0.03 ** 2

mu_x = F.avg_pool2d(x, 3, 1, 0)
mu_y = F.avg_pool2d(y, 3, 1, 0)

# (input, kernel, stride, padding)
sigma_x = F.avg_pool2d(x ** 2, 3, 1, 0) - mu_x ** 2
sigma_y = F.avg_pool2d(y ** 2, 3, 1, 0) - mu_y ** 2
sigma_xy = F.avg_pool2d(x * y, 3, 1, 0) - mu_x * mu_y

SSIM_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
SSIM_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)

SSIM = SSIM_n / SSIM_d

return torch.clamp((1 - SSIM) / 2, 0, 1)


def photometric_error(img1: torch.Tensor, img2: torch.Tensor,
flow: torch.Tensor, valid: torch.Tensor):
img1_warped = warp_flow(img2, flow)
l1_err = (img1_warped * valid - img1 * valid).abs()
ssim_err = SSIM_error(img1_warped * valid, img1 * valid)
return l1_err.mean(), ssim_err.mean()
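
A minimal sketch of how the new helpers compose into the self-supervised objective, assuming core/ is on sys.path (as train-selfsupervised.py arranges) and reusing the 0.84 SSIM weight from the training script below:

import torch

from utils.utils import photometric_error, warp_flow

B, C, H, W = 2, 3, 64, 96
img1 = torch.rand(B, C, H, W) * 255     # two RGB frames in [0, 255]
img2 = torch.rand(B, C, H, W) * 255
flow = torch.zeros(B, 2, H, W)          # zero flow: img2 is sampled in place
valid = torch.ones(B, 1, H, W)          # all pixels valid

warped = warp_flow(img2, flow)          # backward-warp img2 into img1's frame
l1_err, ssim_err = photometric_error(img1, img2, flow, valid)
loss = (1 - 0.84) * l1_err + 0.84 * ssim_err   # same mix as sequence_loss below
print(loss.item())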
225 changes: 225 additions & 0 deletions train-selfsupervised.py
@@ -0,0 +1,225 @@
from __future__ import division, print_function

import sys # nopep8
from pathlib import Path # nopep8

sys.path.append('core') # nopep8

import argparse
import os
import time
from typing import List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn_backend
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from logger import Logger
from raft import RAFT
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.utils import photometric_error

import datasets
import evaluate

try:
from torch.cuda.amp import GradScaler
except ImportError:
# dummy GradScaler for PyTorch < 1.6
class GradScaler:
        def __init__(self, enabled=False):
pass

def scale(self, loss):
return loss

def unscale_(self, optimizer):
pass

def step(self, optimizer):
optimizer.step()

def update(self):
pass


# exclude extremely large displacements
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 5000
SSIM_WEIGHT = 0.84


def sequence_loss(flow_preds: List[torch.Tensor], flow_gt: torch.Tensor,
image1: torch.Tensor, image2: torch.Tensor,
valid: torch.Tensor, gamma=0.8, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """

n_predictions = len(flow_preds)
flow_loss: torch.Tensor = 0.0

    # exclude invalid pixels and extremely large displacements
mag = torch.sum(flow_gt**2, dim=1).sqrt()
valid = (valid >= 0.5) & (mag < max_flow)

for i in range(n_predictions):
i_weight = gamma**(n_predictions - i - 1)
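        # e.g. with iters=12 and gamma=0.8 the final prediction gets weight 1.0
        # while the first gets 0.8**11 ≈ 0.086, so later refinements dominate the loss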
l1_err, ssim_err = photometric_error(image1, image2, flow_preds[i], valid[:, None])
i_loss = (1 - SSIM_WEIGHT) * l1_err + SSIM_WEIGHT * ssim_err
flow_loss += i_weight * i_loss

epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
epe = epe.view(-1)[valid.view(-1)]

metrics = {
'epe': epe.mean().item(),
'1px': (epe < 1).float().mean().item(),
'3px': (epe < 3).float().mean().item(),
'5px': (epe < 5).float().mean().item(),
}

return flow_loss, metrics


def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)


def fetch_optimizer(args, model, steps):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr,
weight_decay=args.wdecay, eps=args.epsilon)

scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, steps+100,
pct_start=0.05, cycle_momentum=False, anneal_strategy='linear')

return optimizer, scheduler


def train(args):
model = nn.DataParallel(RAFT(args), device_ids=args.gpus)
print("Parameter Count: %d" % count_parameters(model))

if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt),
strict=(not args.allow_nonstrict))

model.cuda()
model.train()

if args.freeze_bn:
model.module.freeze_bn()

train_loader = datasets.fetch_dataloader(args)
optimizer, scheduler = fetch_optimizer(args, model,
len(train_loader) * args.num_epochs)

scaler = GradScaler(enabled=args.mixed_precision)
logger = Logger(args.name)

VAL_FREQ = 5000
add_noise = True
best_evaluation = None

for epoch in range(args.num_epochs):
logger.initPbar(len(train_loader), epoch + 1)
for batch_idx, data_blob in enumerate(train_loader):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]

if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
                image1 = (image1 + stdv *
                          torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
                image2 = (image2 + stdv *
                          torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

flow_predictions = model(image1, image2, iters=args.iters)

loss, metrics = sequence_loss(flow_predictions, flow,
image1, image2, valid,
args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

scaler.step(optimizer)
scheduler.step()
scaler.update()

logger.push({'loss': loss.item()})

logger.closePbar()
PATH = 'checkpoints/%s/model.pth' % args.name
torch.save(model.state_dict(), PATH)

results = {}
for val_dataset in args.validation:
if val_dataset == 'chairs':
results.update(evaluate.validate_chairs(model.module))
elif val_dataset == 'sintel':
results.update(evaluate.validate_sintel(model.module))
elif val_dataset == 'kitti':
results.update(evaluate.validate_kitti(model.module))
logger.write_dict(results, 'epoch')

evaluation_score = np.mean(list(results.values()))
if best_evaluation is None or evaluation_score < best_evaluation:
best_evaluation = evaluation_score
PATH = 'checkpoints/%s/model-best.pth' % args.name
torch.save(model.state_dict(), PATH)

model.train()
if args.freeze_bn:
model.module.freeze_bn()

logger.close()

return best_evaluation


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--name', default='raft', help="name your experiment")
parser.add_argument('--freeze_bn', action='store_true',
help="freeze the batch norm layer")
parser.add_argument('--restore_ckpt', help="restore checkpoint")
parser.add_argument('--allow_nonstrict', action='store_true',
help='allow non-strict loading')
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')

parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=6)
parser.add_argument('--image_size', type=int,
nargs='+', default=[368, 768])
parser.add_argument('--gpus', type=int, nargs='+', default=[0])
parser.add_argument('--mixed_precision',
action='store_true', help='use mixed precision')

parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--wdecay', type=float, default=.00005)
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--gamma', type=float, default=0.8,
help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()

torch.manual_seed(1234)
np.random.seed(1234)

cudnn_backend.benchmark = True

os.makedirs(Path(__file__).parent.joinpath('checkpoints', args.name),
exist_ok=True)

train(args)
27 changes: 27 additions & 0 deletions train-selfsupervised.sh
@@ -0,0 +1,27 @@
#!/bin/bash
mkdir -p checkpoints

cmd_scratch="python -u train-selfsupervised.py \
--name raft-sintel-selfsupervised-scratch \
--validation sintel \
--num_epochs 100 \
--batch_size 6 \
--lr 0.0004 \
--wdecay 0.00001"

cmd_transfer="python -u train-selfsupervised.py \
--name raft-sintel-selfsupervised-transfer \
--validation sintel \
--restore_ckpt checkpoints/raft-things.pth \
--freeze_bn \
--num_epochs 100 \
--batch_size 6 \
--lr 0.000125 \
--wdecay 0.00001 \
--gamma=0.85"

# echo ${cmd_scratch}
# eval ${cmd_scratch}

echo ${cmd_transfer}
eval ${cmd_transfer}