diff --git a/README.md b/README.md index 650275e..950c596 100644 --- a/README.md +++ b/README.md @@ -7,13 +7,20 @@ Zachary Teed and Jia Deng
+## Todos + - [ ] Train supervised to 200 (See if scratch is good enough) + - [ ] Train self-supervised to 200 (Using scratch / transfer) + - [ ] Run experiments of ctx size = { 96, 64 } + - [ ] Add self-collected dataset experiments + - [ ] Try-out context attention module ([GMFlowNet](https://github.com/xiaofeng94/GMFlowNet)) + ## Requirements The code has been tested with PyTorch 1.6 and Cuda 10.1. ```Shell conda create --name raft conda activate raft -conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch ``` +Then follow the commands inside `setup-env.sh` ## Demos Pretrained models can be downloaded by running diff --git a/core/raft.py b/core/raft.py index 652b81a..e9601f9 100644 --- a/core/raft.py +++ b/core/raft.py @@ -27,14 +27,14 @@ def __init__(self, args): self.args = args if args.small: - self.hidden_dim = hdim = 96 - self.context_dim = cdim = 64 + self.hidden_dim = hdim = (args.hidden // 4) * 3 + self.context_dim = cdim = args.context // 2 args.corr_levels = 4 args.corr_radius = 3 else: - self.hidden_dim = hdim = 128 - self.context_dim = cdim = 128 + self.hidden_dim = hdim = args.hidden + self.context_dim = cdim = args.context args.corr_levels = 4 args.corr_radius = 4 diff --git a/setup-env.sh b/setup-env.sh index 354afba..0cbe21a 100644 --- a/setup-env.sh +++ b/setup-env.sh @@ -1,8 +1,10 @@ +#!/usr/bin/env bash + +echo Please run this script manually exit 1 -conda create -n raft-dl2023 -conda activate raft-dl2023 + conda install -y python=3.8 conda install -y cudatoolkit=11.1 -c conda-forge conda install -y pytorch==1.8.0 torchvision==0.9.0 -c pytorch -conda install -y tensorboard=2.10.0 matplotlib scipy tqdm -pip install opencv-python \ No newline at end of file +conda install -y matplotlib scipy tensorboard=2.10.0 tqdm +pip install opencv-python diff --git a/train-selfsupervised.py b/train-selfsupervised.py index 25accb3..c4b6c94 100644 --- a/train-selfsupervised.py +++ b/train-selfsupervised.py @@ -6,8 +6,9 @@ import argparse import os +from collections import OrderedDict from pathlib import Path -from typing import List +from typing import Any, List import numpy as np import torch @@ -101,8 +102,14 @@ def train(args): print("Parameter Count: %d" % count_parameters(model)) if args.restore_ckpt is not None: - model.load_state_dict(torch.load(args.restore_ckpt), - strict=(not args.allow_nonstrict)) + checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt) + if args.reset_context: + weight = OrderedDict() + for key, val in checkpoint.items(): + if '.cnet.' not in key: + weight[key] = val + checkpoint = weight + model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict)) model.cuda() model.train() @@ -206,7 +213,16 @@ def train(args): parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') parser.add_argument('--add_noise', action='store_true') + + parser.add_argument('--hidden', type=int, default=128, + help='The hidden size of the updater') + parser.add_argument('--context', type=int, default=128, + help='The context size of the updater') + parser.add_argument('--reset_context', action='store_true') + args = parser.parse_args() + if args.hidden != 128 or args.context != 128: + args.reset_context = True torch.manual_seed(1234) np.random.seed(1234) diff --git a/train-supervised.py b/train-supervised.py index cf594ad..484a047 100644 --- a/train-supervised.py +++ b/train-supervised.py @@ -6,7 +6,9 @@ import argparse import os +from collections import OrderedDict from pathlib import Path +from typing import Any import numpy as np import torch @@ -94,8 +96,14 @@ def train(args): print("Parameter Count: %d" % count_parameters(model)) if args.restore_ckpt is not None: - model.load_state_dict(torch.load(args.restore_ckpt), - strict=(not args.allow_nonstrict)) + checkpoint: OrderedDict[str, Any] = torch.load(args.restore_ckpt) + if args.reset_context: + weight = OrderedDict() + for key, val in checkpoint.items(): + if '.cnet.' not in key: + weight[key] = val + checkpoint = weight + model.load_state_dict(checkpoint, strict=(not args.allow_nonstrict)) model.cuda() model.train() @@ -199,7 +207,16 @@ def train(args): parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') parser.add_argument('--add_noise', action='store_true') + + parser.add_argument('--hidden', type=int, default=128, + help='The hidden size of the updater') + parser.add_argument('--context', type=int, default=128, + help='The context size of the updater') + parser.add_argument('--reset_context', action='store_true') + args = parser.parse_args() + if args.hidden != 128 or args.context != 128: + args.reset_context = True torch.manual_seed(1234) np.random.seed(1234)