Skip to content

Commit

Permalink
Train custom NYCU dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
hm-ysjiang committed Jun 12, 2023
1 parent 036b66a commit 663a70c
Show file tree
Hide file tree
Showing 8 changed files with 629 additions and 90 deletions.
10 changes: 9 additions & 1 deletion ablation-globalmatching.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ supervised_gm="python -u train-supervised.py \
--restore_ckpt checkpoints/raft-things.pth \
--global_matching"

# Supervised training command with --global_matching enabled and no
# --restore_ckpt argument (run name: supervised-scratch).
# NOTE(review): the variable name `supervised_gm` is also assigned just above
# in this script, so this assignment replaces that value -- confirm intended.
supervised_gm="python -u train-supervised.py \
--name supervised-scratch \
--num_epochs 200 \
--batch_size 4 \
--lr 0.0004 \
--wdecay 0.00001 \
--global_matching"


# Select which of the command strings defined above is executed.
cmd=$supervised # Change this line
cmd=$supervised_gm # Change this line

# Print the selected command, then run it through eval.
echo ${cmd}
eval ${cmd}
2 changes: 1 addition & 1 deletion ablation-upsampling.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ full="python -u train-supervised.py \
--wloss_l1recon 2.5 \
--wloss_ssimrecon 2.5"

# Select which of the command strings defined above is executed.
cmd=$baseline # Change this line
cmd=$plus_ssim # Change this line

# Print the selected command, then run it through eval.
echo ${cmd}
eval ${cmd}
113 changes: 113 additions & 0 deletions compose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import sys

sys.path.append('core') # nopep8

import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from raft import RAFT
from tqdm import trange
from utils.flow_viz import flow_to_image
from utils.utils import InputPadder

DEVICE = 'cuda'


def to_tensor(x_np):
    """Convert an H x W x C image array into a batched float tensor on DEVICE.

    The last axis is re-ordered with indices ``[2, 1, 0]`` (BGR -> RGB for
    frames read through OpenCV), the array is rearranged to channel-first
    layout, cast to float32 and given a leading batch axis of size 1.
    """
    channel_order = [2, 1, 0]
    chw = torch.from_numpy(x_np[:, :, channel_order]).permute(2, 0, 1)
    batched = chw.float().unsqueeze(0)
    return batched.to(DEVICE)


def compose(args):
    """Write a side-by-side video of the input frames and their optical flow.

    Reads every frame of ``args.input`` into memory, runs the RAFT model
    stored at ``args.model`` on each pair of consecutive frames, and writes
    ``visualization/composed.avi`` where each output frame shows the source
    frame on the left and the colour-coded flow on the right.

    Args:
        args: parsed command-line namespace. ``input`` and ``model`` are
            paths, ``warmup`` enables re-using the previous low-resolution
            flow as initialisation; the remaining attributes are forwarded
            to the ``RAFT`` constructor.

    Exits the process with status 1 when the video cannot be opened, holds
    fewer than two frames, cannot be read completely, or the output file
    cannot be created.
    """
    vid = cv2.VideoCapture(args.input)
    if not vid.isOpened():
        print('Cannot open video file!')
        sys.exit(1)

    # try/finally so the capture handle is released on every exit path,
    # including the sys.exit() calls below.
    try:
        n_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = vid.get(cv2.CAP_PROP_FPS)

        # Flow needs a pair of frames; fewer would also make the output
        # buffer below have a non-positive first dimension.
        if n_frames < 2:
            print('Need at least two frames to compute optical flow!')
            sys.exit(1)

        frames_input = np.empty((n_frames, height, width, 3), np.uint8)
        for frame_idx in range(n_frames):
            frame_ok, frame = vid.read()
            if not frame_ok:
                print('Error while reading frames!')
                sys.exit(1)
            frames_input[frame_idx] = frame
    finally:
        vid.release()
    print('Read %d x %d, %d frames@%.2fFPS.' % (width, height, n_frames, fps))

    frames_output = np.empty((n_frames - 1, height, width * 2, 3), np.uint8)

    # The weights are loaded through a DataParallel wrapper and the wrapper
    # is discarded afterwards (presumably because the checkpoint keys carry
    # the "module." prefix -- confirm against the training script).
    model = nn.DataParallel(RAFT(args))
    checkpoint = torch.load(args.model)
    weight = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(weight)
    model = model.module
    model.to(DEVICE)
    model.eval()

    flow_init = None
    with torch.no_grad():
        for frame_idx in trange(n_frames - 1, ncols=120):
            image1 = to_tensor(frames_input[frame_idx])
            image2 = to_tensor(frames_input[frame_idx + 1])

            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            flow_low, flow_up = model(
                image1, image2, iters=20, test_mode=True, flow_init=flow_init)
            if args.warmup:
                flow_init = flow_low.detach()

            # Undo the padding so the flow matches the original frame size;
            # otherwise the concatenation below fails whenever the padder
            # actually had to enlarge the frame.
            # (assumes the upstream RAFT InputPadder.unpad API)
            flow = padder.unpad(flow_up[0]).cpu().permute(1, 2, 0).numpy()
            flow_viz = flow_to_image(flow)

            frames_output[frame_idx] = np.concatenate(
                [frames_input[frame_idx], flow_viz[:, :, [2, 1, 0]]], axis=1)

    out = cv2.VideoWriter('visualization/composed.avi',
                          cv2.VideoWriter_fourcc(*'XVID'),
                          fps, (width * 2, height))
    try:
        if not out.isOpened():
            print('Cannot open output video file!')
            sys.exit(1)
        for frame in frames_output:
            out.write(frame)
    finally:
        out.release()


if __name__ == '__main__':
    # Command-line entry point: parse the options, make sure the output
    # directory exists, then build the composed visualization video.
    parser = argparse.ArgumentParser()

    # Required path arguments.
    for flag, description in (('--input', 'The input video file'),
                              ('--model', 'The model weight')):
        parser.add_argument(flag, type=str, required=True, help=description)

    # Boolean switches (all default to False).
    for flag, description in (
            ('--warmup', 'use warm-up mode'),
            ('--small', 'use small model'),
            ('--mixed_precision', 'use mixed precision'),
            ('--alternate_corr', 'use efficent correlation implementation')):
        parser.add_argument(flag, action='store_true', help=description)

    # Integer sizes of the update block.
    for flag, description in (
            ('--hidden', 'The hidden size of the updater'),
            ('--context', 'The context size of the updater')):
        parser.add_argument(flag, type=int, default=128, help=description)

    args = parser.parse_args()

    os.makedirs('visualization', exist_ok=True)

    compose(args)
41 changes: 41 additions & 0 deletions core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,38 @@ def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
return img1, img2, flow, valid.float()


class NYCUData(data.Dataset):
    """Pairs of consecutive ``.jpg`` frames, without flow labels.

    ``root`` is expected to contain one sub-directory per scene. Inside each
    scene the ``*.jpg`` files are sorted by name and every two neighbouring
    files form one sample. Scenes are visited in sorted order so that the
    index -> pair mapping does not depend on the directory listing order.
    """

    def __init__(self, root='datasets/NYCU_set'):
        super().__init__()
        self.image_list = []

        # os.listdir returns entries in arbitrary order; sort for a
        # reproducible dataset layout.
        for scene in sorted(os.listdir(root)):
            image_files = sorted(glob(osp.join(root, scene, '*.jpg')))
            # zip over the shifted list yields nothing for 0 or 1 files.
            self.image_list += [[first, second] for first, second
                                in zip(image_files, image_files[1:])]

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _load_rgb(path) -> torch.Tensor:
        """Read one image file and return it as a float [3, H, W] tensor."""
        img = np.array(frame_utils.read_gen(path)).astype(np.uint8)
        if img.ndim == 2:
            # Grayscale: replicate the single channel instead of slicing the
            # width axis, which is what `[..., :3]` would do on a 2-D array.
            img = np.tile(img[..., None], (1, 1, 3))
        else:
            # Drop any channel beyond the first three (e.g. alpha).
            img = img[..., :3]
        return torch.from_numpy(img).permute(2, 0, 1).float()

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (first, second) frame tensors of pair ``index``.

        The index wraps around the number of pairs.
        """
        index = index % len(self.image_list)
        path1, path2 = self.image_list[index]
        return self._load_rgb(path1), self._load_rgb(path2)


# class FlyingChairs(FlowDataset):
# def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
# super(FlyingChairs, self).__init__(aug_params)
Expand Down Expand Up @@ -345,3 +377,12 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):

print('Training with %d image pairs' % len(train_dataset))
return train_loader


def fetch_nycu(args):
    """Build the shuffled training DataLoader over the NYCU frame pairs.

    Uses ``args.batch_size`` as batch size, four worker processes, pinned
    memory, and drops the last incomplete batch. Prints the number of
    available image pairs before returning the loader.
    """
    dataset = NYCUData()
    n_pairs = len(dataset)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
    )

    print('Training with %d image pairs' % n_pairs)
    return loader
64 changes: 61 additions & 3 deletions core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,21 @@ def warp_flow(x, flow, use_mask=False):
ouptut: [B, C, H, W]
"""
vgrid = create_flow_grid(flow)
return warp_vgrid(x, vgrid, use_mask)


def warp_vgrid(x: torch.Tensor, vgrid: torch.Tensor, use_mask=False):
"""
warp an image/tensor (im2) back to im1, according to the optical flow
Inputs:
x: [B, C, H, W] (im2)
flow: [B, 2, H, W] flow
Returns:
ouptut: [B, C, H, W]
"""
output = F.grid_sample(x, vgrid, align_corners=True)
if use_mask:
mask = autograd.Variable(torch.ones(x.size())).to(x.get_device())
mask = torch.autograd.Variable(torch.ones(x.size())).to(x.device)
mask = F.grid_sample(mask, vgrid, align_corners=True)
mask[mask < 0.9999] = 0
mask[mask > 0] = 1
Expand Down Expand Up @@ -156,6 +168,15 @@ def photometric_error(img1: torch.Tensor, img2: torch.Tensor,
return l1_err.mean(), ssim_err.mean()


def photometric_error_masked(img1: torch.Tensor, img2: torch.Tensor,
                             vgrid: torch.Tensor, valid: torch.Tensor):
    """Masked L1 and SSIM reconstruction errors between img1 and warped img2.

    ``img2`` is warped with ``vgrid``; both the warped image and ``img1``
    are multiplied by ``valid`` before the errors are computed. Each mean
    error is divided by ``valid.mean() + 1e-6``.
    (presumably ``valid`` is a 0/1 mask broadcastable to the images -- confirm
    against the caller)

    Returns:
        Tuple ``(l1, ssim)`` of normalised mean errors.
    """
    warped_masked = warp_vgrid(img2, vgrid) * valid
    target_masked = img1 * valid

    # Normaliser: mean of the mask, offset to avoid a division by zero.
    norm = valid.mean() + 1e-6

    l1_mean = (warped_masked - target_masked).abs().mean()
    ssim_mean = SSIM_error(warped_masked, target_masked).mean()
    return l1_mean / norm, ssim_mean / norm


# GMFlowNet
@torch.no_grad()
def compute_supervision_match(flow, occlusions, scale: int):
Expand Down Expand Up @@ -184,12 +205,49 @@ def out_bound_mask(pt, w, h):

return conf_matrix_gt

# GMFlowNet
def compute_match_loss(conf, conf_gt):

def compute_match_loss(conf, conf_gt):  # GMFlowNet
    """Negative log-likelihood loss between a match matrix and its target.

    Entries where ``conf_gt == 1`` contribute ``-log(conf)`` and entries
    where ``conf_gt == 0`` contribute ``-log(1 - conf)``. Each group is
    averaged separately and the two means are summed. ``conf`` is clamped to
    ``[1e-6, 1 - 1e-6]`` first so the logarithms stay finite.

    A group with no entries contributes zero instead of the NaN that the
    mean of an empty selection would produce.

    Returns:
        A 0-dim tensor holding the loss.
    """
    pos_mask, neg_mask = conf_gt == 1, conf_gt == 0

    conf = torch.clamp(conf, 1e-6, 1 - 1e-6)

    # Zero that stays attached to conf's graph, so backward() still works
    # when both groups are empty.
    loss = conf.sum() * 0.0
    if pos_mask.any():
        loss = loss + (-torch.log(conf[pos_mask])).mean()
    if neg_mask.any():
        loss = loss + (-torch.log(1 - conf[neg_mask])).mean()

    return loss


def magsq(x: torch.Tensor, dim):
    """Return the sum of squared entries of ``x`` along ``dim``.

    The reduced dimension is kept (size 1) whenever ``dim`` is not ``None``.
    """
    keep = dim is not None
    squared = x.pow(2)
    return squared.sum(dim, keepdim=keep)


def create_border_mask(tensor: torch.Tensor, ratio=0.1):
    """Build a mask that is 0 on a border band and 1 in the interior.

    Args:
        tensor: 4-D tensor; its batch size, height, width, dtype and device
            are used for the mask.
        ratio: border width as a fraction of ``min(H, W)``, rounded up.

    Returns:
        A detached ``[B, 1, H, W]`` tensor (an expanded view of a single
        ``H x W`` mask).
    """
    B, _, H, W = tensor.shape
    sz = int(np.ceil(min(H, W) * ratio))
    border_mask = torch.zeros((H, W), dtype=tensor.dtype, device=tensor.device)
    # Explicit end indices: a `sz:-sz` slice would be empty for sz == 0 and
    # wrongly leave the whole mask at zero.
    border_mask[sz:H - sz, sz:W - sz] = 1.0
    border_mask = border_mask.view(1, 1, H, W).expand(B, -1, -1, -1)
    return border_mask.detach()


def fwdbwd_occ_mask(flow_fwd: torch.Tensor, flow_bwd: torch.Tensor,
                    vgrid_fwd: torch.Tensor, vgrid_bwd: torch.Tensor,
                    use_border_mask=False, return_warpdiff=False):
    """Forward/backward consistency masks for a pair of flow fields.

    Each flow is added to the opposite flow warped by its own sampling grid.
    A pixel is marked occluded when the squared magnitude of that sum
    exceeds ``0.01 * (|flow_fwd|^2 + |flow_bwd|^2) + 0.5``; the returned
    masks are 1 where the pixel is *not* occluded.

    Args:
        flow_fwd, flow_bwd: forward and backward flow; squared magnitudes are
            taken over dimension 1.
        vgrid_fwd, vgrid_bwd: sampling grids handed to ``warp_vgrid``.
        use_border_mask: additionally multiply both masks by
            ``create_border_mask(flow_fwd)``.
        return_warpdiff: also return the two flow sums.

    Returns:
        ``(mask_fwd, mask_bwd)`` or, with ``return_warpdiff``,
        ``(mask_fwd, mask_bwd, flow_fwd_warpdiff, flow_bwd_warpdiff)``.
    """
    # Bring each flow into the other frame and form the round-trip sums.
    diff_fwd = flow_fwd + warp_vgrid(flow_bwd, vgrid_fwd, True)
    diff_bwd = flow_bwd + warp_vgrid(flow_fwd, vgrid_bwd, True)

    # Magnitude-dependent threshold shared by both directions.
    threshold = 0.01 * (magsq(flow_fwd, 1) + magsq(flow_bwd, 1)) + 0.5

    mask_fwd = 1 - (magsq(diff_fwd, 1) > threshold).float()
    mask_bwd = 1 - (magsq(diff_bwd, 1) > threshold).float()

    if use_border_mask:
        border = create_border_mask(flow_fwd)
        mask_fwd = border * mask_fwd
        mask_bwd = border * mask_bwd

    if not return_warpdiff:
        return mask_fwd, mask_bwd
    return mask_fwd, mask_bwd, diff_fwd, diff_bwd
Loading

0 comments on commit 663a70c

Please sign in to comment.