Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
JiawangBian committed Apr 3, 2020
1 parent 37939ac commit 841619c
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 3 deletions.
Binary file removed misc/mask.png
Binary file not shown.
Binary file removed misc/vo.png
Binary file not shown.
84 changes: 84 additions & 0 deletions run_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch

from imageio import imread, imsave
from scipy.misc import imresize
import numpy as np
from path import Path
import argparse
from tqdm import tqdm

from models import DispResNet
from utils import tensor2array

parser = argparse.ArgumentParser(description='Inference script for DispNet learned with \
Structure from Motion Learner inference on KITTI Dataset',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--output-disp", action='store_true', help="save disparity img")
parser.add_argument("--output-depth", action='store_true', help="save depth img")
parser.add_argument("--pretrained", required=True, type=str, help="pretrained DispResNet path")
parser.add_argument("--img-height", default=256, type=int, help="Image height")
parser.add_argument("--img-width", default=832, type=int, help="Image width")
parser.add_argument("--no-resize", action='store_true', help="no resizing is done")

parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file")
parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory")
parser.add_argument("--output-dir", default='output', type=str, help="Output directory")
parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
parser.add_argument('--resnet-layers', required=True, type=int, default=18, choices=[18, 50],
help='depth network architecture.')

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def main():
args = parser.parse_args()
if not(args.output_disp or args.output_depth):
print('You must at least output one value !')
return

disp_net = DispResNet(args.resnet_layers, False).to(device)
weights = torch.load(args.pretrained)
disp_net.load_state_dict(weights['state_dict'])
disp_net.eval()

dataset_dir = Path(args.dataset_dir)
output_dir = Path(args.output_dir)
output_dir.makedirs_p()

if args.dataset_list is not None:
with open(args.dataset_list, 'r') as f:
test_files = [dataset_dir/file for file in f.read().splitlines()]
else:
test_files = sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])

print('{} files to test'.format(len(test_files)))

for file in tqdm(test_files):

img = imread(file).astype(np.float32)

h, w, _ = img.shape
if (not args.no_resize) and (h != args.img_height or w != args.img_width):
img = imresize(img, (args.img_height, args.img_width)).astype(np.float32)
img = np.transpose(img, (2, 0, 1))

tensor_img = torch.from_numpy(img).unsqueeze(0)
tensor_img = ((tensor_img/255 - 0.45)/0.225).to(device)

output = disp_net(tensor_img)[0]

file_path, file_ext = file.relpath(args.dataset_dir).splitext()
file_name = '-'.join(file_path.splitall())

if args.output_disp:
disp = (255*tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
imsave(output_dir/'{}_disp{}'.format(file_name, file_ext), np.transpose(disp, (1, 2, 0)))
if args.output_depth:
depth = 1/output
depth = (255*tensor2array(depth, max_value=10, colormap='rainbow')).astype(np.uint8)
imsave(output_dir/'{}_depth{}'.format(file_name, file_ext), np.transpose(depth, (1, 2, 0)))


if __name__ == '__main__':
main()
6 changes: 6 additions & 0 deletions scripts/run_inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
INPUT_DIR=/media/bjw/Disk/Dataset/kitti_odometry/sequences/09/image_2
OUTPUT_DIR=results/
DISP_NET=checkpoints/resnet18_depth_256/dispnet_model_best.pth.tar

python3 run_inference.py --pretrained $DISP_NET --resnet-layers 18 \
--dataset-dir $INPUT_DIR --output-dir $OUTPUT_DIR --output-disp
4 changes: 2 additions & 2 deletions scripts/test_kitti_pose.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
POSE_NET=checkpoints/resnet50_pose_256/10-22-18:36/exp_pose_model_best.pth.tar
KITIT_VO=/media/bjw/Disk/Dataset/kitti_odom/
POSE_NET=checkpoints/resnet50_pose_256/exp_pose_model_best.pth.tar
KITIT_VO=/media/bjw/Disk/Dataset/kitti_odometry/

python test_pose.py $POSE_NET \
--img-height 256 --img-width 832 \
Expand Down
2 changes: 1 addition & 1 deletion test_disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

parser = argparse.ArgumentParser(description='Script for DispNet testing with corresponding groundTruth',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--pretrained-dispnet", required=True, type=str, help="pretrained DispNet path")
parser.add_argument("--pretrained-dispnet", required=True, type=str, help="pretrained DispResNet path")
parser.add_argument("--img-height", default=256, type=int, help="Image height")
parser.add_argument("--img-width", default=832, type=int, help="Image width")
parser.add_argument("--min-depth", default=1e-3)
Expand Down

0 comments on commit 841619c

Please sign in to comment.