Skip to content

Commit

Permalink
added training code
Browse files Browse the repository at this point in the history
  • Loading branch information
Zach Teed committed Jul 31, 2020
1 parent dc370f8 commit a1d8344
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 13 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ Zachary Teed and Jia Deng<br/>
<img src="RAFT.png">

## Requirements
The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
The code has been tested with PyTorch 1.6 and Cuda 10.1.
```Shell
conda create --name raft
conda activate raft
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
conda install matplotlib
conda install tensorboard
conda install scipy
Expand Down Expand Up @@ -67,13 +67,12 @@ python evaluate.py --model=models/raft-things.pth --dataset=sintel
```

## Training
Training code will be made available in the next few days
<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` which can be visualized using tensorboard
We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` which can be visualized using tensorboard
```Shell
./train_standard.sh
```

If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU)
```Shell
./train_mixed.sh
``` -->
```
10 changes: 5 additions & 5 deletions core/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """

if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')

elif args.stage == 'things':
Expand All @@ -210,22 +210,22 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
train_dataset = clean_dataset + final_dataset

elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')

if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things

elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
train_dataset = KITTI(aug_params, split='training')

train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
Expand Down
7 changes: 4 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def update(self):


# exclude extremly large displacements
MAX_FLOW = 500
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 5000

Expand Down Expand Up @@ -181,13 +181,14 @@ def train(args):

loss, metrics = sequence_loss(flow_predictions, flow, valid)
scaler.scale(loss).backward()

scaler.unscale_(optimizer)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

scaler.step(optimizer)
scheduler.step()
scaler.update()


logger.push(metrics)

if total_steps % VAL_FREQ == VAL_FREQ - 1:
Expand Down
6 changes: 6 additions & 0 deletions train_mixed.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
6 changes: 6 additions & 0 deletions train_standard.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/bin/bash
mkdir -p checkpoints
python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001

0 comments on commit a1d8344

Please sign in to comment.