Skip to content

Commit

Permalink
Sintel-tweaked supervised training
Browse files Browse the repository at this point in the history
  • Loading branch information
hm-ysjiang committed Jun 3, 2023
1 parent aac9dd5 commit 57a27c2
Show file tree
Hide file tree
Showing 11 changed files with 750 additions and 290 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ pytorch_env
models
build
correlation.egg-info

checkpoints
runs
5 changes: 5 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"python.analysis.extraPaths": [
"core"
]
}
405 changes: 249 additions & 156 deletions core/datasets.py

Large diffs are not rendered by default.

94 changes: 94 additions & 0 deletions core/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from pathlib import Path
from typing import Literal

from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class Logger:
def __init__(self, name, step_init=0, SUM_FREQ=20):
self.SUM_FREQ = SUM_FREQ
self.name = name

self.total_steps = step_init
self.epoch = 0
self.epoch_size = 0

self.running_loss = {}

self.writer = None
self.pbar = None

def _print_training_status(self):
for k in self.running_loss:
self.running_loss[k] /= self.SUM_FREQ

training_str = f'Ep {self.epoch:3d}'
if (total_loss := self.running_loss.get('loss', None)) is not None:
training_str += f'; loss {total_loss:3.3f}'
self.pbar.set_description(training_str)

if self.writer is None:
self.open()

self.writer.add_scalar('epoch', self.total_steps / self.epoch_size,
self.total_steps)
for k in self.running_loss:
self.writer.add_scalar(k, self.running_loss[k],
self.total_steps)
self.running_loss = {}

def push(self, metrics):
self.total_steps += 1
self.pbar.update(1)

for key in metrics:
if key not in self.running_loss:
self.running_loss[key] = 0.0

self.running_loss[key] += metrics[key]

if self.total_steps % self.SUM_FREQ == 0:
self._print_training_status()

def write_dict(self, results, rel: Literal['step', 'epoch'] = 'step'):
if rel not in ('step', 'epoch'):
raise ValueError(rel)

if self.writer is None:
self.open()

for key in results:
self.writer.add_scalar(key, results[key],
self.total_steps if rel == 'step' else self.total_steps / self.epoch_size)

def open(self):
rootdir = Path(__file__).parent.parent
self.writer = SummaryWriter(rootdir.joinpath('runs', self.name))

def close(self):
self.closePbar()
self.writer.close()

def initPbar(self, epoch_size, epoch, ncols=120):
self.epoch = epoch
self.epoch_size = epoch_size
self.pbar = tqdm(total=epoch_size, desc=f'Ep {epoch:3d}', ncols=ncols)

def closePbar(self, accuracies=None, lastLR=None):
if self.pbar is not None:
if accuracies is not None:
train, test = accuracies
desc_str = self.pbar.desc[:-2]
test_str = f'{test:.3f}'
if test > 0.87:
test_str = '\033[92m' + test_str + '\033[0m'
desc_str += f', accu ({train:.3f},{test_str}), l-LR {lastLR:.2e}'
self.pbar.set_description(desc_str)
self.pbar.close()

def write_viz(self, image):
if self.writer is None:
self.open()

self.writer.add_image('visualization', image, self.total_steps)
4 changes: 2 additions & 2 deletions core/utils/augmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def spatial_transform(self, img1, img2, flow, valid):


def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
# img1, img2 = self.color_transform(img1, img2)
# img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

img1 = np.ascontiguousarray(img1)
Expand Down
43 changes: 23 additions & 20 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import sys

from tqdm import trange
sys.path.append('core')

from PIL import Image
Expand Down Expand Up @@ -97,32 +99,33 @@ def validate_sintel(model, iters=32):
""" Peform validation using the Sintel (train) split """
model.eval()
results = {}
for dstype in ['clean', 'final']:
val_dataset = datasets.MpiSintel(split='training', dstype=dstype)
epe_list = []
with torch.no_grad():
for dstype in ['clean']:
val_dataset = datasets.Sintel(split='validate', dstype=dstype)
epe_list = []

for val_id in range(len(val_dataset)):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()
for val_id in trange(len(val_dataset), desc='Evaluating', ncols=120):
image1, image2, flow_gt, _ = val_dataset[val_id]
image1 = image1[None].cuda()
image2 = image2[None].cuda()

padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()
flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()

epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())
epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())

epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)
epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)

print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)

return results

Expand Down
8 changes: 8 additions & 0 deletions setup-env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
exit 1
conda create -n raft-dl2023
conda activate raft-dl2023
conda install -y python=3.8
conda install -y cudatoolkit=11.1 -c conda-forge
conda install -y pytorch==1.8.0 torchvision==0.9.0 -c pytorch
conda install -y tensorboard=2.10.0 matplotlib scipy tqdm
pip install opencv-python
137 changes: 137 additions & 0 deletions test_kitti_occ.ipynb

Large diffs are not rendered by default.

126 changes: 126 additions & 0 deletions test_sintel_occ.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 57a27c2

Please sign in to comment.