Skip to content

Commit

Permalink
Demo
Browse files Browse the repository at this point in the history
  • Loading branch information
hm-ysjiang committed Jun 4, 2023
1 parent 788de8d commit 6cc7c32
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 67 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ build
correlation.egg-info

checkpoints
runs
runs
visualization
46 changes: 24 additions & 22 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,38 @@
import sys
sys.path.append('core')

import sys # nopep8

sys.path.append('core') # nopep8
import argparse
import glob
import os
from pathlib import Path

import cv2
import glob
import numpy as np
import torch
from PIL import Image

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder



DEVICE = 'cuda'


def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)


def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
img = img[0].permute(1, 2, 0).cpu().numpy()
flo = flo[0].permute(1, 2, 0).cpu().numpy()

# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)

# import matplotlib.pyplot as plt
# plt.imshow(img_flo / 255.0)
# plt.show()

cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
cv2.waitKey()
return img_flo[:, :, [2, 1, 0]].astype(np.uint8)


def demo(args):
Expand All @@ -49,8 +45,8 @@ def demo(args):

with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
glob.glob(os.path.join(args.path, '*.jpg'))

images = sorted(images)
for imfile1, imfile2 in zip(images[:-1], images[1:]):
image1 = load_image(imfile1)
Expand All @@ -60,16 +56,22 @@ def demo(args):
image1, image2 = padder.pad(image1, image2)

flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
cv2.imwrite(f'visualization/{Path(imfile1).name}',
viz(image1, flow_up))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--path', help="dataset for evaluation")
parser.add_argument('--model', type=str, help="restore checkpoint")
parser.add_argument('--path', type=str, default='demo-frames',
help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
parser.add_argument('--mixed_precision',
action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true',
help='use efficent correlation implementation')
args = parser.parse_args()

os.makedirs('visualization', exist_ok=True)

demo(args)
56 changes: 29 additions & 27 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,45 @@
import sys
import sys # nopep8

from tqdm import trange
sys.path.append('core')
sys.path.append('core') # nopep8

from PIL import Image
import argparse
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

import datasets
from utils import flow_viz
from utils import frame_utils

from raft import RAFT
from tqdm import trange
from utils import frame_utils
from utils.utils import InputPadder, forward_interpolate

import datasets


@torch.no_grad()
def create_sintel_submission(model, iters=32, warm_start=False, output_path='sintel_submission'):
""" Create submission for the Sintel leaderboard """
model.eval()
for dstype in ['clean', 'final']:
test_dataset = datasets.MpiSintel(split='test', aug_params=None, dstype=dstype)

test_dataset = datasets.MpiSintel(
split='test', aug_params=None, dstype=dstype)

flow_prev, sequence_prev = None, None
for test_id in range(len(test_dataset)):
image1, image2, (sequence, frame) = test_dataset[test_id]
if sequence != sequence_prev:
flow_prev = None

padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
image1, image2 = padder.pad(
image1[None].cuda(), image2[None].cuda())

flow_low, flow_pr = model(image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
flow_low, flow_pr = model(
image1, image2, iters=iters, flow_init=flow_prev, test_mode=True)
flow = padder.unpad(flow_pr[0]).permute(1, 2, 0).cpu().numpy()

if warm_start:
flow_prev = forward_interpolate(flow_low[0])[None].cuda()

output_dir = os.path.join(output_path, dstype, sequence)
output_file = os.path.join(output_dir, 'frame%04d.flo' % (frame+1))

Expand Down Expand Up @@ -112,19 +110,21 @@ def validate_sintel(model, iters=32):
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)

flow_low, flow_pr = model(image1, image2, iters=iters, test_mode=True)
flow_low, flow_pr = model(
image1, image2, iters=iters, test_mode=True)
flow = padder.unpad(flow_pr[0]).cpu()

epe = torch.sum((flow - flow_gt)**2, dim=0).sqrt()
epe_list.append(epe.view(-1).numpy())

epe_all = np.concatenate(epe_list)
epe = np.mean(epe_all)
px1 = np.mean(epe_all<1)
px3 = np.mean(epe_all<3)
px5 = np.mean(epe_all<5)
px1 = np.mean(epe_all < 1)
px3 = np.mean(epe_all < 3)
px5 = np.mean(epe_all < 5)

print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" % (dstype, epe, px1, px3, px5))
print("Validation (%s) EPE: %f, 1px: %f, 3px: %f, 5px: %f" %
(dstype, epe, px1, px3, px5))
results[dstype] = np.mean(epe_list)

return results
Expand Down Expand Up @@ -174,8 +174,10 @@ def validate_kitti(model, iters=24):
parser.add_argument('--model', help="restore checkpoint")
parser.add_argument('--dataset', help="dataset for evaluation")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
parser.add_argument('--mixed_precision',
action='store_true', help='use mixed precision')
parser.add_argument('--alternate_corr', action='store_true',
help='use efficent correlation implementation')
args = parser.parse_args()

model = torch.nn.DataParallel(RAFT(args))
Expand All @@ -189,12 +191,12 @@ def validate_kitti(model, iters=24):

with torch.no_grad():
if args.dataset == 'chairs':
raise NotImplementedError
validate_chairs(model.module)

elif args.dataset == 'sintel':
validate_sintel(model.module)

elif args.dataset == 'kitti':
raise NotImplementedError
validate_kitti(model.module)


12 changes: 3 additions & 9 deletions train-selfsupervised.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,21 @@
from __future__ import division, print_function

import sys # nopep8
from pathlib import Path # nopep8

sys.path.append('core') # nopep8

import argparse
import os
import time
from pathlib import Path
from typing import List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn_backend
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from logger import Logger
from raft import RAFT
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils.utils import photometric_error

import datasets
Expand Down Expand Up @@ -70,7 +63,8 @@ def sequence_loss(flow_preds: List[torch.Tensor], flow_gt: torch.Tensor,

for i in range(n_predictions):
i_weight = gamma**(n_predictions - i - 1)
l1_err, ssim_err = photometric_error(image1, image2, flow_preds[i], valid[:, None])
l1_err, ssim_err = photometric_error(
image1, image2, flow_preds[i], valid[:, None])
i_loss = (1 - SSIM_WEIGHT) * l1_err + SSIM_WEIGHT * ssim_err
flow_loss += i_weight * i_loss

Expand Down
9 changes: 1 addition & 8 deletions train-supervised.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,20 @@
from __future__ import division, print_function
from pathlib import Path # nopep8

import sys # nopep8

sys.path.append('core') # nopep8

import argparse
import os
import time
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn_backend
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from logger import Logger
from raft import RAFT
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import datasets
import evaluate
Expand Down

0 comments on commit 6cc7c32

Please sign in to comment.