Skip to content

Commit

Permalink
Ablations
Browse files Browse the repository at this point in the history
  • Loading branch information
hm-ysjiang committed Jun 9, 2023
1 parent c298b86 commit 036b66a
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 35 deletions.
22 changes: 22 additions & 0 deletions ablation-globalmatching.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
supervised="python -u train-supervised.py \
--name supervised-transfer \
--num_epochs 50 \
--batch_size 4 \
--lr 0.000125 \
--wdecay 0.00001 \
--restore_ckpt checkpoints/raft-things.pth"

supervised_gm="python -u train-supervised.py \
--name supervised-transfer \
--num_epochs 50 \
--batch_size 4 \
--lr 0.000125 \
--wdecay 0.00001 \
--restore_ckpt checkpoints/raft-things.pth \
--global_matching"


cmd=$supervised # Change this line

echo ${cmd}
eval ${cmd}
8 changes: 4 additions & 4 deletions ablation-upsampling.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@ plus_l1="python -u train-supervised.py \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001 \
--wloss_l1recon 1.0"
--wloss_l1recon 2.5"

plus_ssim="python -u train-supervised.py \
--name upsampling-plusssim \
--num_epochs 100 \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001 \
--wloss_ssimrecon 1.0"
--wloss_ssimrecon 2.5"

full="python -u train-supervised.py \
--name upsampling-full \
--num_epochs 100 \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001 \
--wloss_l1recon 1.0 \
--wloss_ssimrecon 1.0"
--wloss_l1recon 2.5 \
--wloss_ssimrecon 2.5"

cmd=$baseline # Change this line

Expand Down
1 change: 1 addition & 0 deletions core/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
corr = CorrBlock.corr(fmap1, fmap2)

batch, h1, w1, dim, h2, w2 = corr.shape
self.corrMap = corr.view(batch, h1 * w1, h2 * w2) # GMFlowNet
corr = corr.reshape(batch*h1*w1, dim, h2, w2)

self.corr_pyramid.append(corr)
Expand Down
73 changes: 55 additions & 18 deletions core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
class autocast:
def __init__(self, enabled):
pass

def __enter__(self):
pass

def __exit__(self, *args):
pass

Expand All @@ -31,7 +33,7 @@ def __init__(self, args):
self.context_dim = cdim = args.context // 2
args.corr_levels = 4
args.corr_radius = 3

else:
self.hidden_dim = hdim = args.hidden
self.context_dim = cdim = args.context
Expand All @@ -46,14 +48,20 @@ def __init__(self, args):

# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim)
self.fnet = SmallEncoder(
output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(
output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(
self.args, hidden_dim=hdim, input_dim=cdim)

else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim)
self.fnet = BasicEncoder(
output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(
output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(
self.args, hidden_dim=hdim, input_dim=cdim)

def freeze_bn(self):
for m in self.modules():
Expand All @@ -75,15 +83,14 @@ def upsample_flow(self, flow, mask):
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)

up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)


def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
def forward(self, image1, image2, iters=12, flow_init=None, global_matching=False, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """

image1 = 2 * (image1 / 255.0) - 1.0
Expand All @@ -97,12 +104,14 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_

# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])

fmap1, fmap2 = self.fnet([image1, image2])

batch_size, _, fmap_height, fmap_width = fmap1.shape
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
corr_fn = AlternateCorrBlock(
fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

Expand All @@ -114,18 +123,46 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
inp = torch.relu(inp)

coords0, coords1 = self.initialize_flow(image1)
softCorrMap = F.softmax(corr_fn.corrMap, dim=2) \
* F.softmax(corr_fn.corrMap, dim=1)

if flow_init is not None:
coords1 = coords1 + flow_init
elif global_matching:
# GMFlowNet
match12, match_idx12 = softCorrMap.max(dim=2) # (N, fH*fW)
match21, match_idx21 = softCorrMap.max(dim=1)

for b_idx in range(batch_size):
match21_b = match21[b_idx, :]
match_idx12_b = match_idx12[b_idx, :]
match21[b_idx, :] = match21_b[match_idx12_b]

matched = (match12 - match21) == 0 # (N, fH*fW)
coords_index = torch.arange(fmap_height * fmap_width) \
.unsqueeze(0) \
.repeat(batch_size, 1) \
.to(softCorrMap.device)
coords_index[matched] = match_idx12[matched]

# matched coords
coords_index = coords_index.reshape(batch_size,
fmap_height, fmap_width)
coords_x = coords_index % fmap_width
coords_y = coords_index // fmap_width

coords_xy = torch.stack([coords_x, coords_y], dim=1).float()
coords1 = coords_xy

flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
corr = corr_fn(coords1) # index correlation volume

flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
net, up_mask, delta_flow = self.update_block(
net, inp, corr, flow)

# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
Expand All @@ -135,10 +172,10 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)

flow_predictions.append(flow_up)

if test_mode:
return coords1 - coords0, flow_up
return flow_predictions

return flow_predictions, softCorrMap
39 changes: 39 additions & 0 deletions core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,42 @@ def photometric_error(img1: torch.Tensor, img2: torch.Tensor,
l1_err = (img1_warped * valid - img1 * valid).abs()
ssim_err = SSIM_error(img1_warped * valid, img1 * valid)
return l1_err.mean(), ssim_err.mean()


# GMFlowNet
@torch.no_grad()
def compute_supervision_match(flow, occlusions, scale: int):
N, _, H, W = flow.shape
Hc, Wc = int(np.ceil(H / scale)), int(np.ceil(W / scale))

occlusions_c = occlusions[:, :, ::scale, ::scale]
flow_c = flow[:, :, ::scale, ::scale] / scale
occlusions_c = occlusions_c.reshape(N, Hc * Wc)

grid_c = coords_grid(N, Hc, Wc,
device=flow.device).permute(0, 2, 3, 1).reshape(N, Hc * Wc, 2)
warp_c = grid_c + flow_c.permute(0, 2, 3, 1).reshape(N, Hc * Wc, 2)
warp_c = warp_c.round().long()

def out_bound_mask(pt, w, h):
return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)

occlusions_c[out_bound_mask(warp_c, Wc, Hc)] = 1
warp_c = warp_c[..., 0] + warp_c[..., 1] * Wc

b_ids, i_ids = torch.split(torch.nonzero(occlusions_c == 0), 1, dim=1)
conf_matrix_gt = torch.zeros(N, Hc * Wc, Hc * Wc, device=flow.device)
j_ids = warp_c[b_ids, i_ids]
conf_matrix_gt[b_ids, i_ids, j_ids] = 1

return conf_matrix_gt

# GMFlowNet
def compute_match_loss(conf, conf_gt):
pos_mask, neg_mask = conf_gt == 1, conf_gt == 0

conf = torch.clamp(conf, 1e-6, 1 - 1e-6)
loss_pos = -torch.log(conf[pos_mask])
loss_neg = -torch.log(1 - conf[neg_mask])

return loss_pos.mean() + loss_neg.mean()
6 changes: 3 additions & 3 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def validate_sintel(model, iters=32, dstypes: List[Literal['clean', 'final']] =
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_low, flow_pr = model(
image1, image2, iters=iters, test_mode=True)
flow_low, flow_pr = model(image1, image2,
iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()

epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
Expand Down Expand Up @@ -173,7 +173,7 @@ def validate_kitti(model, iters=24):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--dstype', default='clean',
parser.add_argument('--dstype', default='clean',
choices=['clean', 'final', 'mixed'],
help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
Expand Down
14 changes: 11 additions & 3 deletions train-selfsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def train(args):

for epoch in range(epoch_start, args.num_epochs):
logger.initPbar(len(train_loader), epoch + 1)
should_global_matching = args.global_matching and epoch > 0
for batch_idx, data_blob in enumerate(train_loader):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]
Expand All @@ -157,7 +158,9 @@ def train(args):
image2 = (image2 + stdv * torch.randn(*
image2.shape).cuda()).clamp(0.0, 255.0)

flow_predictions = model(image1, image2, iters=args.iters)
flow_predictions, softCorrMap = model(image1, image2,
iters=args.iters,
global_matching=should_global_matching)

loss, metrics = sequence_loss(flow_predictions, flow,
image1, image2, valid,
Expand Down Expand Up @@ -239,8 +242,10 @@ def train(args):
parser.add_argument('--image_size', type=int,
nargs='+', default=[368, 768])
parser.add_argument('--gpus', type=int, nargs='+', default=[0])
parser.add_argument('--mixed_precision',
action='store_true', help='use mixed precision')
parser.add_argument('--mixed_precision', action='store_true',
help='use mixed precision')
parser.add_argument('--global_matching', action='store_true',
help='use global matching before optimization')

parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--wdecay', type=float, default=.00005)
Expand All @@ -261,6 +266,9 @@ def train(args):
if args.hidden != 128 or args.context != 128:
args.reset_context = True
args.name = f'{args.name}-{args.dstype}-ep{args.num_epochs}-c{args.context}'
if args.global_matching:
args.name = f'{args.name}-gm'
print(args)

torch.manual_seed(1234)
np.random.seed(1234)
Expand Down
26 changes: 19 additions & 7 deletions train-supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.optim as optim
from logger import Logger
from raft import RAFT
from utils.utils import photometric_error
from utils.utils import compute_match_loss, compute_supervision_match, photometric_error

import datasets
import evaluate
Expand Down Expand Up @@ -52,12 +52,13 @@ def update(self):

def sequence_loss(flow_preds: List[torch.Tensor], flow_gt: torch.Tensor,
image1: torch.Tensor, image2: torch.Tensor,
valid: torch.Tensor, args, max_flow=MAX_FLOW):
valid: torch.Tensor, softCorrMap: torch.Tensor,
args, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """

n_predictions = len(flow_preds)
gamma = args.gamma
should_recon = args.wloss_l1recon > 0 and args.wloss_ssimrecon > 0
should_recon = args.wloss_l1recon > 0 or args.wloss_ssimrecon > 0
flow_loss: torch.Tensor = 0.0

# exlude invalid pixels and extremely large diplacements
Expand All @@ -75,6 +76,10 @@ def sequence_loss(flow_preds: List[torch.Tensor], flow_gt: torch.Tensor,
photo_loss = (1 - SSIM_WEIGHT) * l1_err * args.wloss_l1recon \
+ SSIM_WEIGHT * ssim_err * args.wloss_ssimrecon
flow_loss += photo_loss
if args.global_matching:
occlusion = 1.0 - valid.float()
gt_match = compute_supervision_match(flow_gt, occlusion[:, None], 8)
flow_loss += 0.01 * compute_match_loss(softCorrMap, gt_match)

epe = torch.sum((flow_preds[-1] - flow_gt)**2, dim=1).sqrt()
epe = epe.view(-1)[valid.view(-1)]
Expand Down Expand Up @@ -152,6 +157,7 @@ def train(args):

for epoch in range(epoch_start, args.num_epochs):
logger.initPbar(len(train_loader), epoch + 1)
should_global_matching = args.global_matching and epoch > 0
for batch_idx, data_blob in enumerate(train_loader):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]
Expand All @@ -163,11 +169,13 @@ def train(args):
image2 = (image2 + stdv * torch.randn(*
image2.shape).cuda()).clamp(0.0, 255.0)

flow_predictions = model(image1, image2, iters=args.iters)
flow_predictions, softCorrMap = model(image1, image2,
iters=args.iters,
global_matching=should_global_matching)

loss, metrics = sequence_loss(flow_predictions, flow,
image1, image2, valid,
args)
softCorrMap, args)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
Expand Down Expand Up @@ -245,8 +253,10 @@ def train(args):
parser.add_argument('--image_size', type=int,
nargs='+', default=[368, 768])
parser.add_argument('--gpus', type=int, nargs='+', default=[0])
parser.add_argument('--mixed_precision',
action='store_true', help='use mixed precision')
parser.add_argument('--mixed_precision', action='store_true',
help='use mixed precision')
parser.add_argument('--global_matching', action='store_true',
help='use global matching before optimization')

parser.add_argument('--iters', type=int, default=12)
parser.add_argument('--wdecay', type=float, default=.00005)
Expand All @@ -271,6 +281,8 @@ def train(args):
if args.hidden != 128 or args.context != 128:
args.reset_context = True
args.name = f'{args.name}-{args.dstype}-ep{args.num_epochs}-c{args.context}'
if args.global_matching:
args.name = f'{args.name}-gm'
print(args)

torch.manual_seed(1234)
Expand Down

0 comments on commit 036b66a

Please sign in to comment.