fixed problems with variational dropout
Zach Teed committed May 25, 2020
1 parent dd91321 commit 3fac647
Showing 5 changed files with 22 additions and 8 deletions.
Binary file added RAFT.png
6 changes: 4 additions & 2 deletions README.md

@@ -4,6 +4,8 @@ This repository contains the source code for our paper:
 [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)<br/>
 Zachary Teed and Jia Deng<br/>
 
+<img src="RAFT.png">
+
 ## Requirements
 Our code was tested using PyTorch 1.3.1 and Python 3. The following additional packages need to be installed
 
@@ -84,11 +86,11 @@ python train.py --name=kitti_ft --image_size 288 896 --dataset=kitti --num_steps
 You can evaluate a model on Sintel and KITTI by running
 
 ```Shell
-python evaluate.py --model=checkpoints/chairs+things.pth
+python evaluate.py --model=models/chairs+things.pth
 ```
 
 or the small model by including the `small` flag
 
 ```Shell
-python evaluate.py --model=checkpoints/small.pth --small
+python evaluate.py --model=models/small.pth --small
 ```
16 changes: 14 additions & 2 deletions core/modules/update.py

@@ -133,8 +133,20 @@ def __init__(self, args, hidden_dim=96):
         self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
         self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
 
+        self.drop_inp = VariationalHidDropout(dropout=args.dropout)
+        self.drop_net = VariationalHidDropout(dropout=args.dropout)
+
+    def reset_mask(self, net, inp):
+        self.drop_inp.reset_mask(inp)
+        self.drop_net.reset_mask(net)
+
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
+
+        if self.training:
+            net = self.drop_net(net)
+            inp = self.drop_inp(inp)
+
         inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
@@ -157,12 +169,12 @@ def reset_mask(self, net, inp):
 
     def forward(self, net, inp, corr, flow):
         motion_features = self.encoder(flow, corr)
-        inp = torch.cat([inp, motion_features], dim=1)
-
         if self.training:
             net = self.drop_net(net)
             inp = self.drop_inp(inp)
 
+
+        inp = torch.cat([inp, motion_features], dim=1)
         net = self.gru(net, inp)
         delta_flow = self.flow_head(net)
 
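The diff calls into `VariationalHidDropout`, whose definition is not part of this commit. For context, here is a minimal sketch of how such a module is typically implemented (an assumed interface matching the `reset_mask`/`forward` calls above, not the repository's exact code): variational, or "locked", dropout samples one Bernoulli mask when `reset_mask` is called and reuses that same mask at every recurrent iteration, instead of resampling on each call the way `nn.Dropout` does.

```python
import torch
import torch.nn as nn

class VariationalHidDropout(nn.Module):
    """Sketch of variational (locked) dropout: one mask per unroll."""

    def __init__(self, dropout=0.0):
        super().__init__()
        self.dropout = dropout
        self.mask = None

    def reset_mask(self, x):
        # Sample a fresh Bernoulli keep-mask shaped like x, rescaled so
        # the expected activation is unchanged (inverted dropout).
        keep = 1.0 - self.dropout
        self.mask = torch.zeros_like(x).bernoulli_(keep) / keep

    def forward(self, x):
        if not self.training or self.dropout == 0:
            return x
        assert self.mask is not None, "call reset_mask() first"
        return self.mask * x  # the same mask at every GRU iteration
```

Read together, the two hunks make the update blocks consistent. The first block (apparently the small model's, given `hidden_dim=96`) gains the dropout layers and a `reset_mask` hook it previously lacked. In the second block, the old ordering sampled the mask on the raw `inp` via `reset_mask` but applied it after the motion features were concatenated, so the mask and tensor shapes presumably no longer matched; applying dropout to `net` and `inp` before the `torch.cat` keeps the mask shape fixed and leaves the motion features untouched.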
2 changes: 1 addition & 1 deletion core/raft.py

@@ -26,7 +26,7 @@ def __init__(self, args):
         args.corr_levels = 4
         args.corr_radius = 4
 
-        if 'dropout' not in args._get_kwargs():
+        if not hasattr(args, 'dropout'):
             args.dropout = 0
 
         # feature network, context network, and update block
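This one-line change is arguably the heart of the commit. `Namespace._get_kwargs()` returns a list of `(name, value)` tuples, so the old membership test compares the string `'dropout'` against tuples and is always true: any dropout rate supplied by the caller was silently overwritten with 0, disabling variational dropout entirely. A standalone demonstration (illustration only, not repository code):

```python
from argparse import Namespace

args = Namespace(dropout=0.25)

# Old guard: _get_kwargs() yields (name, value) tuples, so the string
# 'dropout' is never "in" the list and the test is always True.
print('dropout' not in args._get_kwargs())  # True  -> would reset to 0

# Fixed guard: only fires when the attribute is genuinely missing.
print(not hasattr(args, 'dropout'))         # False -> 0.25 is kept
```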
6 changes: 3 additions & 3 deletions train.py

@@ -21,7 +21,7 @@
 
 # exclude extremly large displacements
 MAX_FLOW = 1000
-SUM_FREQ = 100
+SUM_FREQ = 200
 VAL_FREQ = 5000
 
 
@@ -56,7 +56,7 @@ def sequence_loss(flow_preds, flow_gt, valid):
 
 
 def fetch_dataloader(args):
-    """ Create the data loader for the corresponding trainign set """
+    """ Create the data loader for the corresponding training set """
 
     if args.dataset == 'chairs':
         train_dataset = datasets.FlyingChairs(args, image_size=args.image_size)
@@ -86,7 +86,7 @@ def fetch_optimizer(args, model):
     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
 
     scheduler = optim.lr_scheduler.OneCycleLR(optimizer, args.lr, args.num_steps,
-        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear', final_div_factor=1.0)
+        pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')
 
     return optimizer, scheduler
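Two notes on the training changes. Raising `SUM_FREQ` to 200 simply halves the logging frequency. The scheduler change is more consequential: PyTorch's `OneCycleLR` anneals from `max_lr` down to `initial_lr / final_div_factor`, where `initial_lr = max_lr / div_factor` (default 25). With `final_div_factor=1.0` the learning rate could never decay below its warmup starting value of `max_lr / 25`; dropping the argument restores the default of `1e4`, so the rate anneals to near zero by the end of training. A toy sketch of the resulting schedule (hypothetical model and hyperparameters, not the training script):

```python
import torch
from torch import optim

model = torch.nn.Linear(8, 2)      # stand-in for the real network
num_steps = 100
optimizer = optim.AdamW(model.parameters(), lr=4e-4)

# Same arguments as the updated fetch_optimizer: linear warmup over the
# first 20% of steps, then linear decay toward max_lr / (25 * 1e4).
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, 4e-4, num_steps,
    pct_start=0.2, cycle_momentum=False, anneal_strategy='linear')

for step in range(num_steps):
    optimizer.step()               # forward/backward omitted in this toy
    scheduler.step()
    if step % 25 == 0:
        print(step, scheduler.get_last_lr()[0])
```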
