Commit 841619c (parent: 37939ac)
Showing 6 changed files with 93 additions and 3 deletions.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,84 @@
import torch

from imageio import imread, imsave
from scipy.misc import imresize
import numpy as np
from path import Path
import argparse
from tqdm import tqdm

from models import DispResNet
from utils import tensor2array

parser = argparse.ArgumentParser(description='Inference script for DispNet learned with \
Structure from Motion Learner inference on KITTI Dataset',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--output-disp", action='store_true', help="save disparity img")
parser.add_argument("--output-depth", action='store_true', help="save depth img")
parser.add_argument("--pretrained", required=True, type=str, help="pretrained DispResNet path")
parser.add_argument("--img-height", default=256, type=int, help="Image height")
parser.add_argument("--img-width", default=832, type=int, help="Image width")
parser.add_argument("--no-resize", action='store_true', help="no resizing is done")

parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file")
parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory")
parser.add_argument("--output-dir", default='output', type=str, help="Output directory")
parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob")
parser.add_argument('--resnet-layers', required=True, type=int, default=18, choices=[18, 50],
                    help='depth network architecture.')

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@torch.no_grad()
def main():
    args = parser.parse_args()
    if not (args.output_disp or args.output_depth):
        print('You must output at least one value (disparity or depth)!')
        return

    # Load the pretrained DispResNet weights and switch to evaluation mode.
    disp_net = DispResNet(args.resnet_layers, False).to(device)
    weights = torch.load(args.pretrained)
    disp_net.load_state_dict(weights['state_dict'])
    disp_net.eval()

    dataset_dir = Path(args.dataset_dir)
    output_dir = Path(args.output_dir)
    output_dir.makedirs_p()

    # Collect test images either from a list file or by globbing the dataset directory.
    if args.dataset_list is not None:
        with open(args.dataset_list, 'r') as f:
            test_files = [dataset_dir/file for file in f.read().splitlines()]
    else:
        test_files = sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])

    print('{} files to test'.format(len(test_files)))

    for file in tqdm(test_files):

        img = imread(file).astype(np.float32)

        h, w, _ = img.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            img = imresize(img, (args.img_height, args.img_width)).astype(np.float32)
        img = np.transpose(img, (2, 0, 1))

        # Normalize to the range the network was trained with (mean 0.45, std 0.225).
        tensor_img = torch.from_numpy(img).unsqueeze(0)
        tensor_img = ((tensor_img/255 - 0.45)/0.225).to(device)

        output = disp_net(tensor_img)[0]

        file_path, file_ext = file.relpath(args.dataset_dir).splitext()
        file_name = '-'.join(file_path.splitall())

        if args.output_disp:
            disp = (255*tensor2array(output, max_value=None, colormap='bone')).astype(np.uint8)
            imsave(output_dir/'{}_disp{}'.format(file_name, file_ext), np.transpose(disp, (1, 2, 0)))
        if args.output_depth:
            # Depth is the inverse of the predicted disparity.
            depth = 1/output
            depth = (255*tensor2array(depth, max_value=10, colormap='rainbow')).astype(np.uint8)
            imsave(output_dir/'{}_depth{}'.format(file_name, file_ext), np.transpose(depth, (1, 2, 0)))


if __name__ == '__main__':
    main()
@@ -0,0 +1,6 @@
INPUT_DIR=/media/bjw/Disk/Dataset/kitti_odometry/sequences/09/image_2
OUTPUT_DIR=results/
DISP_NET=checkpoints/resnet18_depth_256/dispnet_model_best.pth.tar

python3 run_inference.py --pretrained $DISP_NET --resnet-layers 18 \
--dataset-dir $INPUT_DIR --output-dir $OUTPUT_DIR --output-disp
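For reference, producing colormapped depth maps instead of disparity is only a flag swap; a minimal sketch reusing the same assumed INPUT_DIR, OUTPUT_DIR, and DISP_NET variables from the script above:

# Sketch: same invocation, but saving depth (1/disparity, 'rainbow' colormap) via --output-depth.
python3 run_inference.py --pretrained $DISP_NET --resnet-layers 18 \
--dataset-dir $INPUT_DIR --output-dir $OUTPUT_DIR --output-depth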