From e6b59c488e8e693a6a25a4c233812bedc9f4b69c Mon Sep 17 00:00:00 2001 From: hm-ysjiang Date: Wed, 7 Jun 2023 13:45:36 +0800 Subject: [PATCH] Enable changes of hidden and context dimensions --- core/raft.py | 4 ++-- core/update.py | 6 +++--- evaluate.py | 2 ++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/core/raft.py b/core/raft.py index e9601f9..565df8f 100644 --- a/core/raft.py +++ b/core/raft.py @@ -48,12 +48,12 @@ def __init__(self, args): if args.small: self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) - self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim) else: self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) - self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim, input_dim=cdim) def freeze_bn(self): for m in self.modules(): diff --git a/core/update.py b/core/update.py index f940497..e3ca492 100644 --- a/core/update.py +++ b/core/update.py @@ -97,10 +97,10 @@ def forward(self, flow, corr): return torch.cat([out, flow], dim=1) class SmallUpdateBlock(nn.Module): - def __init__(self, args, hidden_dim=96): + def __init__(self, args, hidden_dim=96, input_dim=64): super(SmallUpdateBlock, self).__init__() self.encoder = SmallMotionEncoder(args) - self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+input_dim) self.flow_head = FlowHead(hidden_dim, hidden_dim=128) def forward(self, net, inp, corr, flow): @@ -116,7 +116,7 @@ def __init__(self, args, hidden_dim=128, input_dim=128): super(BasicUpdateBlock, self).__init__() self.args = args self.encoder = BasicMotionEncoder(args) - self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+input_dim) self.flow_head = FlowHead(hidden_dim, hidden_dim=256) self.mask = nn.Sequential( diff --git a/evaluate.py b/evaluate.py index d74e550..dfbc1c5 100644 --- a/evaluate.py +++ b/evaluate.py @@ -186,6 +186,8 @@ def validate_kitti(model, iters=24): model = torch.nn.DataParallel(RAFT(args)) checkpoint = torch.load(args.model) + if 'epoch' in checkpoint: + print(f'Epoch {checkpoint["epoch"]}') weight = checkpoint['model'] if 'model' in checkpoint else checkpoint model.load_state_dict(weight)