Skip to content

Commit

Permalink
fixed bug with alternate_corr flag
Browse files Browse the repository at this point in the history
  • Loading branch information
Zach Teed committed Aug 28, 2020
1 parent 01ad964 commit 13198c3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 31 deletions.
30 changes: 5 additions & 25 deletions core/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,6 @@ def corr(fmap1, fmap2):
return corr / torch.sqrt(torch.tensor(dim).float())


class CorrLayer(torch.autograd.Function):
    """Custom autograd Function wrapping a CUDA correlation-sampling kernel.

    Forward samples a local correlation volume between two feature maps at
    the given coordinates within radius ``r``; backward delegates gradient
    computation to the same extension module.

    NOTE(review): ``correlation_cudaz`` is not imported anywhere in this
    hunk — presumably a compiled CUDA extension; confirm against the full
    file (this class was removed by the commit in favor of calling
    ``alt_cuda_corr`` directly).
    """

    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, r):
        # The CUDA kernel requires contiguous memory layout.
        fmap1 = fmap1.contiguous()
        fmap2 = fmap2.contiguous()
        coords = coords.contiguous()
        # Stash the tensors (and the non-tensor radius, via ctx.r)
        # for use in backward.
        ctx.save_for_backward(fmap1, fmap2, coords)
        ctx.r = r
        # The extension returns a 1-tuple; unpack the single tensor.
        corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
        return corr

    @staticmethod
    def backward(ctx, grad_corr):
        fmap1, fmap2, coords = ctx.saved_tensors
        grad_corr = grad_corr.contiguous()
        fmap1_grad, fmap2_grad, coords_grad = \
            correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
        # Four forward inputs -> four gradient slots; the radius ``r`` is a
        # non-differentiable Python int, so its gradient is None.
        return fmap1_grad, fmap2_grad, coords_grad, None


class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
Expand All @@ -92,20 +72,20 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.pyramid.append((fmap1, fmap2))

def __call__(self, coords):
    """Build a multi-level correlation volume around ``coords``.

    The scraped diff hunk interleaves pre- and post-commit statements
    (duplicate ``fmap*_i`` / ``corr`` / ``return`` lines); this is the
    reconstructed post-commit implementation: inputs are made contiguous
    for the CUDA kernel, ``alt_cuda_corr.forward`` returns a 1-tuple that
    is unpacked, and the output is normalized by sqrt(feature dim) to
    match the scaling used by ``corr()`` / ``CorrBlock``.

    Args:
        coords: flow coordinates, shape (B, 2, H, W) before the permute.
            # assumes channels-first layout from the caller — TODO confirm

    Returns:
        Correlation tensor of shape (B, num_levels * (2r+1)^2, H, W).
    """
    coords = coords.permute(0, 2, 3, 1)
    B, H, W, _ = coords.shape
    dim = self.pyramid[0][0].shape[1]

    # Loop invariants hoisted: level-0 fmap1 and the radius never change.
    r = self.radius
    fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()

    corr_list = []
    for i in range(self.num_levels):
        # Kernel requires channels-last, contiguous inputs.
        fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

        # Scale coordinates into level i's resolution.
        coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
        # Extension returns a 1-tuple; unpack the single tensor.
        corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
        corr_list.append(corr.squeeze(1))

    corr = torch.stack(corr_list, dim=1)
    corr = corr.reshape(B, -1, H, W)
    # Same sqrt(dim) normalization as the non-CUDA CorrBlock path.
    return corr / torch.sqrt(torch.tensor(dim).float())
11 changes: 5 additions & 6 deletions core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def __init__(self, args):
args.corr_levels = 4
args.corr_radius = 4

if 'dropout' not in args._get_kwargs():
args.dropout = 0
if 'dropout' not in self.args:
self.args.dropout = 0

if 'alternate_corr' not in args._get_kwargs():
args.alternate_corr = False
if 'alternate_corr' not in self.args:
self.args.alternate_corr = False

# feature network, context network, and update block
if args.small:
Expand All @@ -55,7 +55,6 @@ def __init__(self, args):
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)


def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
Expand Down Expand Up @@ -103,7 +102,7 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

Expand Down

0 comments on commit 13198c3

Please sign in to comment.