diff --git a/README.md b/README.md
index 650275e..950c596 100644
--- a/README.md
+++ b/README.md
@@ -7,13 +7,20 @@ Zachary Teed and Jia Deng
+## Todos
+ - [ ] Train supervised to 200 (See if scratch is good enough)
+ - [ ] Train self-supervised to 200 (Using scratch / transfer)
+ - [ ] Run experiments of ctx size = { 96, 64 }
+ - [ ] Add self-collected dataset experiments
+ - [ ] Try-out context attention module ([GMFlowNet](https://github.com/xiaofeng94/GMFlowNet))
+
## Requirements
The code has been tested with PyTorch 1.6 and Cuda 10.1.
```Shell
conda create --name raft
conda activate raft
-conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
```
+Then follow the commands inside `setup-env.sh`
## Demos
Pretrained models can be downloaded by running
diff --git a/core/raft.py b/core/raft.py
index 652b81a..e9601f9 100644
--- a/core/raft.py
+++ b/core/raft.py
@@ -27,14 +27,14 @@ def __init__(self, args):
self.args = args
if args.small:
- self.hidden_dim = hdim = 96
- self.context_dim = cdim = 64
+ self.hidden_dim = hdim = (args.hidden // 4) * 3
+ self.context_dim = cdim = args.context // 2
args.corr_levels = 4
args.corr_radius = 3
else:
- self.hidden_dim = hdim = 128
- self.context_dim = cdim = 128
+ self.hidden_dim = hdim = args.hidden
+ self.context_dim = cdim = args.context
args.corr_levels = 4
args.corr_radius = 4
diff --git a/setup-env.sh b/setup-env.sh
index 354afba..0cbe21a 100644
--- a/setup-env.sh
+++ b/setup-env.sh
@@ -1,8 +1,10 @@
+#!/usr/bin/env bash
+
+echo Please run this script manually
exit 1
-conda create -n raft-dl2023
-conda activate raft-dl2023
+
conda install -y python=3.8
conda install -y cudatoolkit=11.1 -c conda-forge
conda install -y pytorch==1.8.0 torchvision==0.9.0 -c pytorch
-conda install -y tensorboard=2.10.0 matplotlib scipy tqdm
-pip install opencv-python
\ No newline at end of file
+conda install -y matplotlib scipy tensorboard=2.10.0 tqdm
+pip install opencv-python
diff --git a/train-selfsupervised.py b/train-selfsupervised.py
index 25accb3..c4b6c94 100644
--- a/train-selfsupervised.py
+++ b/train-selfsupervised.py
@@ -6,8 +6,9 @@
import argparse
import os
+from collections import OrderedDict
from pathlib import Path
-from typing import List
+from typing import Any, List
import numpy as np
import torch
@@ -101,8 +102,14 @@ def train(args):
print("Parameter Count: %d" % count_parameters(model))
if args.restore_ckpt is not None:
- model.load_state_dict(torch.load(args.restore_ckpt),
- strict=(not args.allow_nonstrict))
+ checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt)
+ if args.reset_context:
+ weight = OrderedDict()
+ for key, val in checkpoint.items():
+ if '.cnet.' not in key:
+ weight[key] = val
+ checkpoint = weight
+ model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict))
model.cuda()
model.train()
@@ -206,7 +213,16 @@ def train(args):
parser.add_argument('--gamma', type=float, default=0.8,
help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
+
+ parser.add_argument('--hidden', type=int, default=128,
+ help='The hidden size of the updater')
+ parser.add_argument('--context', type=int, default=128,
+ help='The context size of the updater')
+ parser.add_argument('--reset_context', action='store_true')
+
args = parser.parse_args()
+ if args.hidden != 128 or args.context != 128:
+ args.reset_context = True
torch.manual_seed(1234)
np.random.seed(1234)
diff --git a/train-supervised.py b/train-supervised.py
index cf594ad..484a047 100644
--- a/train-supervised.py
+++ b/train-supervised.py
@@ -6,7 +6,9 @@
import argparse
import os
+from collections import OrderedDict
from pathlib import Path
+from typing import Any
import numpy as np
import torch
@@ -94,8 +96,14 @@ def train(args):
print("Parameter Count: %d" % count_parameters(model))
if args.restore_ckpt is not None:
- model.load_state_dict(torch.load(args.restore_ckpt),
- strict=(not args.allow_nonstrict))
+ checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt)
+ if args.reset_context:
+ weight = OrderedDict()
+ for key, val in checkpoint.items():
+ if '.cnet.' not in key:
+ weight[key] = val
+ checkpoint = weight
+ model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict))
model.cuda()
model.train()
@@ -199,7 +207,16 @@ def train(args):
parser.add_argument('--gamma', type=float, default=0.8,
help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
+
+ parser.add_argument('--hidden', type=int, default=128,
+ help='The hidden size of the updater')
+ parser.add_argument('--context', type=int, default=128,
+ help='The context size of the updater')
+ parser.add_argument('--reset_context', action='store_true')
+
args = parser.parse_args()
+ if args.hidden != 128 or args.context != 128:
+ args.reset_context = True
torch.manual_seed(1234)
np.random.seed(1234)