Enable changes of hidden and context dimensions
hm-ysjiang committed Jun 7, 2023
Parent: 7d6c2b5 · Commit: e6b59c4
Showing 3 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions core/raft.py
@@ -48,12 +48,12 @@ def __init__(self, args):
         if args.small:
             self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
             self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
-            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
+            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim)
 
         else:
             self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
             self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
-            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
+            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim)
 
     def freeze_bn(self):
         for m in self.modules():
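Note on the change above: hdim and cdim are hard-coded a few lines earlier in RAFT.__init__ (96/64 for the small model, 128/128 otherwise), the context encoder cnet emits hdim+cdim channels, and the forward pass splits them into the hidden state (hdim) and the context features inp (cdim). Passing input_dim=cdim lets the update block size its GRU for whichever cdim is chosen instead of assuming the defaults. A sketch of how the two dimensions could then be exposed on the command line; the --hidden_dim and --context_dim options are an assumption for illustration, not part of this commit:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--small', action='store_true', help='use the small RAFT model')
    # assumed new options; RAFT.__init__ would read these instead of its hard-coded 96/64 or 128/128
    parser.add_argument('--hidden_dim', type=int, default=128, help='hidden state channels (hdim)')
    parser.add_argument('--context_dim', type=int, default=128, help='context feature channels (cdim)')
    args = parser.parse_args()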
6 changes: 3 additions & 3 deletions core/update.py
@@ -97,10 +97,10 @@ def forward(self, flow, corr):
         return torch.cat([out, flow], dim=1)
 
 class SmallUpdateBlock(nn.Module):
-    def __init__(self, args, hidden_dim=96):
+    def __init__(self, args, hidden_dim=96, input_dim=64):
         super(SmallUpdateBlock, self).__init__()
         self.encoder = SmallMotionEncoder(args)
-        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
+        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+input_dim)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
 
     def forward(self, net, inp, corr, flow):
@@ -116,7 +116,7 @@ def __init__(self, args, hidden_dim=128, input_dim=128):
         super(BasicUpdateBlock, self).__init__()
         self.args = args
         self.encoder = BasicMotionEncoder(args)
-        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
+        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+input_dim)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
 
         self.mask = nn.Sequential(
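For reference, the constants 82 and 128 in these GRUs are the motion-encoder output widths (80 and 126 conv channels plus the 2 flow channels from SmallMotionEncoder and BasicMotionEncoder), and the second term is the number of context-feature channels concatenated with them; previously that term was fixed at 64 for the small block and tied to hidden_dim for the basic one, so the context dimension could not change independently. A quick shape check with a non-default context dimension, as a sketch (assumes core/ is on sys.path, the small model's correlation settings, and made-up tensor sizes):

    from argparse import Namespace
    import torch
    from update import SmallUpdateBlock

    args = Namespace(corr_levels=4, corr_radius=3)   # small-model correlation settings
    hdim, cdim = 96, 48                              # a non-default context dim now works
    block = SmallUpdateBlock(args, hidden_dim=hdim, input_dim=cdim)

    n, h, w = 1, 46, 62
    net  = torch.randn(n, hdim, h, w)                # hidden state
    inp  = torch.randn(n, cdim, h, w)                # context features (cdim channels)
    corr = torch.randn(n, args.corr_levels * (2 * args.corr_radius + 1) ** 2, h, w)
    flow = torch.randn(n, 2, h, w)

    net, _, delta_flow = block(net, inp, corr, flow)
    print(net.shape, delta_flow.shape)               # (1, 96, 46, 62) and (1, 2, 46, 62)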
2 changes: 2 additions & 0 deletions evaluate.py
@@ -186,6 +186,8 @@ def validate_kitti(model, iters=24):

     model = torch.nn.DataParallel(RAFT(args))
     checkpoint = torch.load(args.model)
+    if 'epoch' in checkpoint:
+        print(f'Epoch {checkpoint["epoch"]}')
     weight = checkpoint['model'] if 'model' in checkpoint else checkpoint
     model.load_state_dict(weight)
 
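With this change the loader accepts both the upstream checkpoints (a bare state_dict) and dict-style checkpoints that carry 'epoch' and 'model' keys. A minimal round-trip sketch of that layout; the stand-in module and file name are placeholders, and the dict layout on the saving side is an assumption about the training script:

    import torch

    model = torch.nn.Conv2d(3, 3, 1)  # stand-in; in evaluate.py this is torch.nn.DataParallel(RAFT(args))

    # save a dict-style checkpoint
    torch.save({'epoch': 42, 'model': model.state_dict()}, 'demo-checkpoint.pth')

    # load it, mirroring the logic above
    checkpoint = torch.load('demo-checkpoint.pth')
    if 'epoch' in checkpoint:
        print(f'Epoch {checkpoint["epoch"]}')
    weight = checkpoint['model'] if 'model' in checkpoint else checkpoint
    model.load_state_dict(weight)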
