Update code to support running on different dstypes of Sintel
hm-ysjiang committed Jun 8, 2023
1 parent c41db70 commit 5229455
Showing 7 changed files with 185 additions and 116 deletions.
36 changes: 36 additions & 0 deletions ablation-upsampling.sh
@@ -0,0 +1,36 @@
baseline="python -u train-supervised.py \
--name upsampling-baseline \
--num_epochs 100 \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001"

plus_l1="python -u train-supervised.py \
--name upsampling-plusl1 \
--num_epochs 100 \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001 \
--wloss_l1recon 1.0"

plus_ssim="python -u train-supervised.py \
--name upsampling-plusssim \
--num_epochs 100 \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001 \
--wloss_ssimrecon 1.0"

full="python -u train-supervised.py \
--name upsampling-full \
--num_epochs 100 \
--batch_size 3 \
--lr 0.0004 \
--wdecay 0.00001 \
--wloss_l1recon 1.0 \
--wloss_ssimrecon 1.0"

cmd=$baseline # Change this line

echo ${cmd}
eval ${cmd}
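
To run a different ablation, point the cmd= line at one of the four command strings above (for example cmd=$plus_l1) and execute the script, e.g. bash ablation-upsampling.sh. The four variants differ only in which reconstruction-loss weights (--wloss_l1recon, --wloss_ssimrecon) are enabled on top of the shared training settings.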
51 changes: 35 additions & 16 deletions core/datasets.py
@@ -118,7 +118,7 @@


class Sintel(data.Dataset):
def __init__(self, aug_params=None, split: Literal['training', 'validate'] = 'training', root='datasets/Sintel', dstype='clean'):
def __init__(self, aug_params=None, split: Literal['training', 'validate'] = 'training', root='datasets/Sintel', dstype: Literal['clean', 'final'] = 'clean'):
self.augmentor = None
if aug_params is not None:
self.augmentor = SparseFlowAugmentor(**aug_params)
@@ -135,7 +135,13 @@ def __init__(self, aug_params=None, split: Literal['training', 'validate'] = 'tr
invalid_root = osp.join(root, 'training', 'invalid')
occ_root = osp.join(root, 'training', 'occlusions')

for scene in os.listdir(image_root):
validation_set = {'ambush_5', 'bandage_2', 'market_5', 'temple_2'}
if self.split == 'training':
scene_sets = set(os.listdir(image_root)) - validation_set
else:
scene_sets = validation_set

for scene in scene_sets:
image_list = []
flow_list = []
invalid_list = []
@@ -150,17 +156,22 @@ def __init__(self, aug_params=None, split: Literal['training', 'validate'] = 'tr
scene, '*.png')))
occ_list += sorted(glob(osp.join(occ_root, scene, '*.png')))

split_index = n_pairs // 10
if self.split == 'training':
self.image_list += image_list[split_index:]
self.flow_list += flow_list[split_index:]
self.invalid_list += invalid_list[split_index:]
self.occ_list += occ_list[split_index:]
else:
self.image_list += image_list[:split_index]
self.flow_list += flow_list[:split_index]
self.invalid_list += invalid_list[:split_index]
self.occ_list += occ_list[:split_index]
self.image_list += image_list
self.flow_list += flow_list
self.invalid_list += invalid_list
self.occ_list += occ_list

# split_index = n_pairs // 10
# if self.split == 'training':
# self.image_list += image_list[split_index:]
# self.flow_list += flow_list[split_index:]
# self.invalid_list += invalid_list[split_index:]
# self.occ_list += occ_list[split_index:]
# else:
# self.image_list += image_list[:split_index]
# self.flow_list += flow_list[:split_index]
# self.invalid_list += invalid_list[:split_index]
# self.occ_list += occ_list[:split_index]

def __len__(self):
return len(self.image_list)
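
With this change, the validation split holds out four complete scenes (ambush_5, bandage_2, market_5, temple_2) instead of taking the first tenth of every scene. A minimal usage sketch, assuming the repository's datasets/Sintel layout and that core is on the import path:

    from core import datasets

    # Training split: every scene except the four held-out ones
    train_set = datasets.Sintel(split='training', dstype='clean')
    # Validation split: only the four held-out scenes
    val_set = datasets.Sintel(split='validate', dstype='final')
    print(len(train_set), len(val_set))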
@@ -319,10 +330,18 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
# aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
# train_dataset = KITTI(aug_params, split='training')

train_dataset = Sintel({'crop_size': args.image_size,
'min_scale': -0.2, 'max_scale': 0, 'do_flip': True})
if args.dstype == 'mixed':
sintel_clean = Sintel({'crop_size': args.image_size,
'min_scale': -0.2, 'max_scale': 0, 'do_flip': True}, dstype='clean')
sintel_final = Sintel({'crop_size': args.image_size,
'min_scale': -0.2, 'max_scale': 0, 'do_flip': True}, dstype='final')
train_dataset = sintel_clean + sintel_final
else:
train_dataset = Sintel({'crop_size': args.image_size,
'min_scale': -0.2, 'max_scale': 0, 'do_flip': True}, dstype=args.dstype)

train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
pin_memory=True, shuffle=True, num_workers=4, drop_last=True)

print('Training with %d image pairs' % len(train_dataset))
return train_loader
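
A note on the mixed branch above: torch.utils.data.Dataset defines __add__, so sintel_clean + sintel_final returns a ConcatDataset chaining the two passes. A self-contained illustration with toy datasets:

    import torch
    from torch.utils.data import ConcatDataset, Dataset

    class Toy(Dataset):
        """Tiny stand-in dataset, used only for this illustration."""
        def __init__(self, n):
            self.n = n
        def __len__(self):
            return self.n
        def __getitem__(self, idx):
            return torch.tensor(idx)

    combined = Toy(3) + Toy(5)         # Dataset.__add__ -> ConcatDataset
    assert isinstance(combined, ConcatDataset)
    print(len(combined))               # 8: the lengths of both parts add up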
23 changes: 10 additions & 13 deletions evaluate.py
@@ -4,6 +4,7 @@

import argparse
import os
from typing import List, Literal

import numpy as np
import torch
@@ -93,12 +94,12 @@ def validate_chairs(model, iters=24):


@torch.no_grad()
def validate_sintel(model, iters=32):
def validate_sintel(model, iters=32, dstypes: List[Literal['clean', 'final']] = ['clean']):
""" Peform validation using the Sintel (train) split """
model.eval()
results = {}
with torch.no_grad():
for dstype in ['clean']:
for dstype in dstypes:
val_dataset = datasets.Sintel(split='validate', dstype=dstype)
epe_list = []

@@ -172,7 +173,9 @@ def validate_kitti(model, iters=24):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--dataset', default='sintel', help="dataset for evaluation")
parser.add_argument('--dstype', default='clean',
                    choices=['clean', 'final', 'mixed'],
                    help="Sintel dstype to evaluate on")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision',
action='store_true', help='use mixed precision')
@@ -198,13 +201,7 @@ def validate_kitti(model, iters=24):
# create_kitti_submission(model.module)

with torch.no_grad():
if args.dataset == 'chairs':
raise NotImplementedError
validate_chairs(model.module)

elif args.dataset == 'sintel':
validate_sintel(model.module)

elif args.dataset == 'kitti':
raise NotImplementedError
validate_kitti(model.module)
if args.dstype == 'mixed':
validate_sintel(model.module, dstypes=['clean', 'final'])
else:
validate_sintel(model.module, dstypes=[args.dstype])
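
Example invocations (checkpoint path illustrative): python evaluate.py --model checkpoints/myrun/model-best.pth --dstype clean validates a single pass, while --dstype mixed runs validate_sintel over both clean and final.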
15 changes: 12 additions & 3 deletions test_sintel_occ.ipynb
@@ -14,7 +14,16 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/hmysjiang/miniconda3/envs/raft-dl2023/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
@@ -32,8 +41,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training Set Size: 956\n",
"Validate Set Size 85\n"
"Training Set Size: 845\n",
"Validate Set Size 196\n"
]
}
],
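The updated notebook output (845 training pairs and 196 validation pairs, versus 956 and 85 before) is consistent with the split change in core/datasets.py: validation now holds out four full scenes rather than the first tenth of each scene.
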
47 changes: 29 additions & 18 deletions train-selfsupervised.py
@@ -52,10 +52,11 @@ def update(self):

def sequence_loss(flow_preds: List[torch.Tensor], flow_gt: torch.Tensor,
image1: torch.Tensor, image2: torch.Tensor,
valid: torch.Tensor, gamma=0.8, max_flow=MAX_FLOW):
valid: torch.Tensor, args, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """

n_predictions = len(flow_preds)
gamma = args.gamma
flow_loss: torch.Tensor = 0.0

# exclude invalid pixels and extremely large displacements
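
For context, the body of sequence_loss (outside this hunk) applies RAFT-style exponential weighting so that later refinement iterations dominate the loss; a sketch of that weighting under the names in this signature, assuming the standard RAFT formulation:

    # Sketch only: the actual weighting lines are not shown in this diff.
    for i, pred in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)  # gamma < 1 favors later iterations
        i_loss = (pred - flow_gt).abs()
        flow_loss += i_weight * (valid[:, None] * i_loss).mean()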
@@ -112,15 +113,16 @@ def train(args):
best_evaluation = None
if args.restore_ckpt is not None:
checkpoint = torch.load(args.restore_ckpt)
weight: OrderedDict[str, Any] = checkpoint['model'] if 'model' in checkpoint else checkpoint
weight: OrderedDict[str, Any] = \
checkpoint['model'] if 'model' in checkpoint else checkpoint
if args.reset_context:
_weight = OrderedDict()
for key, val in weight.items():  # iterate the extracted state dict, not the raw checkpoint
if args.context != 128 and \
('.cnet.' in key or '.update_block.gru.' in key):
('.cnet.' in key or '.update_block.gru.' in key):
pass
elif args.hidden != 128 and \
('.update_block.gru.' in key or '.update_block.flow_head.' in key):
('.update_block.gru.' in key or '.update_block.flow_head.' in key):
pass
else:
_weight[key] = val
@@ -159,7 +161,7 @@ def train(args):

loss, metrics = sequence_loss(flow_predictions, flow,
image1, image2, valid,
args.gamma)
args)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
@@ -171,36 +173,42 @@ def train(args):
logger.push({'loss': loss.item()})

logger.closePbar()
PATH = 'checkpoints/%s/model.pth' % args.name
torch.save({
'epoch': epoch + 1,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler,
'best_evaluation': best_evaluation
}, PATH)
}, f'checkpoints/{args.name}/model.pth')

if (epoch + 1) % 50 == 0:
torch.save({
'epoch': epoch + 1,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler,
'best_evaluation': best_evaluation
}, f'checkpoints/{args.name}/model-{epoch + 1}.pth')

results = {}
for val_dataset in args.validation:
if val_dataset == 'chairs':
results.update(evaluate.validate_chairs(model.module))
elif val_dataset == 'sintel':
results.update(evaluate.validate_sintel(model.module))
elif val_dataset == 'kitti':
results.update(evaluate.validate_kitti(model.module))
if args.validation == 'mixed':
results.update(evaluate.validate_sintel(model.module,
dstypes=['clean', 'final']))
else:
results.update(evaluate.validate_sintel(model.module,
dstypes=[args.validation]))
logger.write_dict(results, 'epoch')

evaluation_score = np.mean(list(results.values()))
if best_evaluation is None or evaluation_score < best_evaluation:
best_evaluation = evaluation_score
PATH = 'checkpoints/%s/model-best.pth' % args.name
torch.save({
'epoch': epoch + 1,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler,
'best_evaluation': best_evaluation
}, PATH)
}, f'checkpoints/{args.name}/model-best.pth')

model.train()
if args.freeze_bn:
@@ -220,7 +228,10 @@ def train(args):
parser.add_argument('--allow_nonstrict', action='store_true',
help='allow non-strict loading')
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--validation', type=str, nargs='+')
parser.add_argument('--dstype', type=str, default='clean',
choices=['clean', 'final', 'mixed'])
parser.add_argument('--validation', type=str, default='clean',
choices=['clean', 'final', 'mixed'])

parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--num_epochs', type=int, default=10)
@@ -249,7 +260,7 @@ def train(args):
args = parser.parse_args()
if args.hidden != 128 or args.context != 128:
args.reset_context = True
args.name = f'{args.name}-ep{args.num_epochs}-c{args.context}'
args.name = f'{args.name}-{args.dstype}-ep{args.num_epochs}-c{args.context}'

torch.manual_seed(1234)
np.random.seed(1234)
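Example training invocation (run name and epoch count illustrative): python train-selfsupervised.py --name myrun --dstype mixed --validation mixed --num_epochs 100. With the renaming above, the run is stored under checkpoints/myrun-mixed-ep100-c<context>/, and validation scores both Sintel passes after every epoch.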
(The remaining 2 changed files are not shown.)
