Skip to content

Commit

Permalink
Update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
hm-ysjiang committed Jun 4, 2023
1 parent 6cc7c32 commit 3fb73f7
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 14 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@ Zachary Teed and Jia Deng<br/>

<img src="RAFT.png">

## Todos
- [ ] Train supervised to 200 (See if scratch is good enough)
- [ ] Train self-supervised to 200 (Using scratch / transfer)
- [ ] Run experiments of ctx size = { 96, 64 }
- [ ] Add self-collected dataset experiments
- [ ] Try-out context attention module ([GMFlowNet](https://github.com/xiaofeng94/GMFlowNet))

## Requirements
The code has been tested with PyTorch 1.6 and Cuda 10.1.
```Shell
conda create --name raft
conda activate raft
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
```
Then follow the commands inside `setup-env.sh`

## Demos
Pretrained models can be downloaded by running
Expand Down
8 changes: 4 additions & 4 deletions core/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ def __init__(self, args):
self.args = args

if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
self.hidden_dim = hdim = (args.hidden // 4) * 3
self.context_dim = cdim = args.context // 2
args.corr_levels = 4
args.corr_radius = 3

else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
self.hidden_dim = hdim = args.hidden
self.context_dim = cdim = args.context
args.corr_levels = 4
args.corr_radius = 4

Expand Down
10 changes: 6 additions & 4 deletions setup-env.sh
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#!/usr/bin/env bash

echo Please run this script manually
exit 1
conda create -n raft-dl2023
conda activate raft-dl2023

conda install -y python=3.8
conda install -y cudatoolkit=11.1 -c conda-forge
conda install -y pytorch==1.8.0 torchvision==0.9.0 -c pytorch
conda install -y tensorboard=2.10.0 matplotlib scipy tqdm
pip install opencv-python
conda install -y matplotlib scipy tensorboard=2.10.0 tqdm
pip install opencv-python
22 changes: 19 additions & 3 deletions train-selfsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import argparse
import os
from collections import OrderedDict
from pathlib import Path
from typing import List
from typing import Any, List

import numpy as np
import torch
Expand Down Expand Up @@ -101,8 +102,14 @@ def train(args):
print("Parameter Count: %d" % count_parameters(model))

if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt),
strict=(not args.allow_nonstrict))
checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt)
if args.reset_context:
weight = OrderedDict()
for key, val in checkpoint.items():
if '.cnet.' not in key:
weight[key] = val
checkpoint = weight
model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict))

model.cuda()
model.train()
Expand Down Expand Up @@ -206,7 +213,16 @@ def train(args):
parser.add_argument('--gamma', type=float, default=0.8,
help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')

parser.add_argument('--hidden', type=int, default=128,
help='The hidden size of the updater')
parser.add_argument('--context', type=int, default=128,
help='The context size of the updater')
parser.add_argument('--reset_context', action='store_true')

args = parser.parse_args()
if args.hidden != 128 or args.context != 128:
args.reset_context = True

torch.manual_seed(1234)
np.random.seed(1234)
Expand Down
21 changes: 19 additions & 2 deletions train-supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import argparse
import os
from collections import OrderedDict
from pathlib import Path
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -94,8 +96,14 @@ def train(args):
print("Parameter Count: %d" % count_parameters(model))

if args.restore_ckpt is not None:
model.load_state_dict(torch.load(args.restore_ckpt),
strict=(not args.allow_nonstrict))
checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt)
if args.reset_context:
weight = OrderedDict()
for key, val in checkpoint.items():
if '.cnet.' not in key:
weight[key] = val
checkpoint = weight
model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict))

model.cuda()
model.train()
Expand Down Expand Up @@ -199,7 +207,16 @@ def train(args):
parser.add_argument('--gamma', type=float, default=0.8,
help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')

parser.add_argument('--hidden', type=int, default=128,
help='The hidden size of the updater')
parser.add_argument('--context', type=int, default=128,
help='The context size of the updater')
parser.add_argument('--reset_context', action='store_true')

args = parser.parse_args()
if args.hidden != 128 or args.context != 128:
args.reset_context = True

torch.manual_seed(1234)
np.random.seed(1234)
Expand Down

0 comments on commit 3fb73f7

Please sign in to comment.