added small 1M parameter model
Zach Teed committed Aug 24, 2020
1 parent c86b3dc commit 01ad964
Showing 5 changed files with 18 additions and 27 deletions.
17 changes: 8 additions & 9 deletions README.md
@@ -24,21 +24,13 @@ Pretrained models can be downloaded by running
```Shell
./download_models.sh
```
-or downloaded from [google drive](https://drive.google.com/file/d/10-BYgHqRNPGvmNUWr8razjb1xHu55pyA/view?usp=sharing)
+or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing)

You can demo a trained model on a sequence of frames
```Shell
python demo.py --model=models/raft-things.pth --path=demo-frames
```

-## (Optional) Efficient Implementation
-You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension
-```Shell
-cd alt_cuda_corr && python setup.py install && cd ..
-```
-and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
-
-
## Required Data
To evaluate/train RAFT, you will need to download the required datasets.
* [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs)
@@ -83,3 +75,10 @@ If you have an RTX GPU, training can be accelerated using mixed precision. You ca…
```Shell
./train_mixed.sh
```

+## (Optional) Efficient Implementation
+You can optionally use our alternate (efficient) implementation by compiling the provided cuda extension
+```Shell
+cd alt_cuda_corr && python setup.py install && cd ..
+```
+and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag. Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass.
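As a quick sanity check after compiling, you can confirm the extension is importable before passing `--alternate_corr` (a minimal sketch; the module name `alt_cuda_corr` is assumed from the extension directory, so verify it against the provided `setup.py`):
```Python
# Hedged sanity check: is the compiled CUDA extension importable?
# The module name "alt_cuda_corr" is an assumption based on the
# directory name; consult alt_cuda_corr/setup.py if this fails.
try:
    import alt_cuda_corr  # noqa: F401
    print("alt_cuda_corr found; --alternate_corr should be usable")
except ImportError:
    print("alt_cuda_corr missing; the all-pairs implementation will be used")
```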
2 changes: 1 addition & 1 deletion download_models.sh
@@ -1,3 +1,3 @@
#!/bin/bash
-wget https://www.dropbox.com/s/npt24nvhoojdr0n/models.zip
+wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip
unzip models.zip
18 changes: 5 additions & 13 deletions train.py
@@ -44,7 +44,7 @@ def update(self):
VAL_FREQ = 5000


-def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
+def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """

n_predictions = len(flow_preds)
@@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
valid = (valid >= 0.5) & (mag < max_flow)

for i in range(n_predictions):
-i_weight = 0.8**(n_predictions - i - 1)
+i_weight = gamma**(n_predictions - i - 1)
i_loss = (flow_preds[i] - flow_gt).abs()
flow_loss += i_weight * (valid[:, None] * i_loss).mean()

@@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):

return flow_loss, metrics

-def show_image(img):
-    img = img.permute(1,2,0).cpu().numpy()
-    plt.imshow(img/255.0)
-    plt.show()
-    # cv2.imshow('image', img/255.0)
-    # cv2.waitKey()

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)


def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
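Since the commit's headline is a ~1M parameter model, `count_parameters` is the natural way to verify the size; a hedged usage sketch (the `small` switch and the other constructor arguments are assumptions about how RAFT is configured elsewhere in the repo, not shown in this diff):
```Python
import argparse
import sys

sys.path.append('core')  # assumption: RAFT lives in core/raft.py, as demo.py expects
from raft import RAFT

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Hypothetical configuration: the exact flags that select the ~1M parameter
# model are not part of this diff; check the repo's argument parser.
args = argparse.Namespace(small=True, dropout=0.0,
                          mixed_precision=False, alternate_corr=False)
print(f"trainable parameters: {count_parameters(RAFT(args)):,}")
```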
@@ -169,17 +164,14 @@ def train(args):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]

-# show_image(image1[0])
-# show_image(image2[0])
-
if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
image2 = (image2 + stdv * torch.randn(*image2.shape).cuda()).clamp(0.0, 255.0)

flow_predictions = model(image1, image2, iters=args.iters)

-loss, metrics = sequence_loss(flow_predictions, flow, valid)
+loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
@@ -188,7 +180,6 @@ def train(args):
scheduler.step()
scaler.update()


logger.push(metrics)

if total_steps % VAL_FREQ == VAL_FREQ - 1:
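For context, the lines around this change follow the standard `torch.cuda.amp` recipe; a condensed sketch of one optimization step (not the full training loop — the helper name is ours, and it assumes `sequence_loss` from above is in scope):
```Python
import torch

def training_step(model, optimizer, scheduler, scaler,
                  image1, image2, flow, valid, args):
    optimizer.zero_grad()
    flow_predictions = model(image1, image2, iters=args.iters)
    loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
    scaler.scale(loss).backward()   # scale the loss so fp16 grads don't underflow
    scaler.unscale_(optimizer)      # unscale before clipping the raw gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
    scaler.step(optimizer)          # skips the update if any gradient overflowed
    scheduler.step()
    scaler.update()                 # adapt the loss scale for the next step
    return metrics
```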
@@ -243,6 +234,7 @@ def train(args):
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
+parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()

4 changes: 2 additions & 2 deletions train_mixed.sh
@@ -2,5 +2,5 @@
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
-python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
-python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
+python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision
+python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision
4 changes: 2 additions & 2 deletions train_standard.sh
@@ -2,5 +2,5 @@
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
-python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
-python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001
+python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85
+python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85
