diff --git a/img_synthesis_main.py b/img_synthesis_main.py
new file mode 100644
index 0000000..e68338b
--- /dev/null
+++ b/img_synthesis_main.py
@@ -0,0 +1,184 @@
+
+from __future__ import absolute_import, division, print_function
+
+import os
+import sys
+import glob
+import argparse
+import numpy as np
+import PIL.Image as pil
+import matplotlib as mpl
+import matplotlib.cm as cm
+import cv2
+
+import torch
+from torchvision import transforms, datasets
+import networks
+
+def parse_args():
+
+    parser = argparse.ArgumentParser(
+        description='Simple testing function for Monodepthv2 models.')
+    parser.add_argument('--image_path', type=str,
+                        help='path to a test image or folder of images', required=True)
+    parser.add_argument('--model_name', type=str,
+                        help='name of a pretrained model to use',
+                        choices=[
+                            "mono_640x192",
+                            "stereo_640x192",
+                            "mono+stereo_640x192",
+                            "mono_no_pt_640x192",
+                            "stereo_no_pt_640x192",
+                            "mono+stereo_no_pt_640x192",
+                            "mono_1024x320",
+                            "stereo_1024x320",
+                            "mono+stereo_1024x320"])
+    parser.add_argument('--ext', type=str,
+                        help='image extension to search for in folder', default="jpg")
+    parser.add_argument("--no_cuda",
+                        help='if set, disables CUDA',
+                        action='store_true')
+    parser.add_argument("--pred_metric_depth",
+                        help='if set, predicts metric depth instead of disparity. (This only '
+                             'makes sense for stereo-trained KITTI models).',
+                        action='store_true')
+    parser.add_argument('--output_image_path', type=str,
+                        help='path to folder of output images', required=True)
+    parser.add_argument('--beta', type=float,
+                        help='degree of haze', default=1.)
+    parser.add_argument('--airlight', type=float,
+                        help='atmospheric light', default=255.)
+
+    return parser.parse_args()
+
+
+def gen_haze(clean_img, depth_img, beta=1.0, A=150):
+
+    depth_img_3c = np.zeros_like(clean_img)
+    depth_img_3c[:,:,0] = depth_img
+    depth_img_3c[:,:,1] = depth_img
+    depth_img_3c[:,:,2] = depth_img
+
+    norm_depth_img = depth_img_3c/255
+    trans = np.exp(-norm_depth_img*beta)
+
+    hazy = clean_img*trans + A*(1-trans)
+    hazy = np.array(hazy, dtype=np.uint8)
+
+    return hazy
+
+
+def test_simple(args):
+
+    assert args.model_name is not None, \
+        "You must specify the --model_name parameter; see README.md for an example"
+
+    if torch.cuda.is_available() and not args.no_cuda:
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    if args.pred_metric_depth and "stereo" not in args.model_name:
+        print("Warning: The --pred_metric_depth flag only makes sense for stereo-trained KITTI "
+              "models. For mono-trained models, output depths will not be in metric space.")
+
+    # download_model_if_doesnt_exist(args.model_name)
+    model_path = os.path.join("models", args.model_name)
+    print("-> Loading model from ", model_path)
+    encoder_path = os.path.join(model_path, "encoder.pth")
+    depth_decoder_path = os.path.join(model_path, "depth.pth")
+
+    # LOADING PRETRAINED MODEL
+    print("   Loading pretrained encoder")
+    encoder = networks.ResnetEncoder(18, False)
+    loaded_dict_enc = torch.load(encoder_path, map_location=device)
+
+    # EXTRACT THE HEIGHT AND WIDTH OF IMAGE THAT THIS MODEL WAS TRAINED WITH
+    feed_height = loaded_dict_enc['height']
+    feed_width = loaded_dict_enc['width']
+    filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()}
+    encoder.load_state_dict(filtered_dict_enc)
+    encoder.to(device)
+    encoder.eval()
+
+    print("   Loading pretrained decoder")
+    depth_decoder = networks.DepthDecoder(
+        num_ch_enc=encoder.num_ch_enc, scales=range(4))
+
+    loaded_dict = torch.load(depth_decoder_path, map_location=device)
+    depth_decoder.load_state_dict(loaded_dict)
+
+    depth_decoder.to(device)
+    depth_decoder.eval()
+
+    # FINDING INPUT IMAGES
+    if os.path.isfile(args.image_path):
+        # Only testing on a single image
+        paths = [args.image_path]
+        output_directory = os.path.dirname(args.image_path)
+    elif os.path.isdir(args.image_path):
+        # Searching folder for images
+        paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext)))
+        output_directory = args.image_path
+    else:
+        raise Exception("Can not find args.image_path: {}".format(args.image_path))
+
+    print("-> Predicting on {:d} test images".format(len(paths)))
+
+    # CHECK IF OUTPUT FOLDER EXISTS
+    if not os.path.isdir(args.output_image_path):
+        os.makedirs(args.output_image_path)
+
+    output_dir = args.output_image_path
+
+    # PREDICTING ON EACH IMAGE IN TURN
+    with torch.no_grad():
+        for idx, image_path in enumerate(paths):
+
+            if image_path.endswith("_disp.jpg"):
+                # don't try to predict disparity for a disparity image!
+                continue
+
+            # LOAD IMAGE AND PREPROCESS
+            input_image = pil.open(image_path).convert('RGB')
+            clean_img = input_image.copy()
+            original_width, original_height = input_image.size
+            input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
+            input_image = transforms.ToTensor()(input_image).unsqueeze(0)
+
+            # PREDICTION
+            input_image = input_image.to(device)
+            features = encoder(input_image)
+            outputs = depth_decoder(features)
+
+            disp = outputs[("disp", 0)]
+            disp_resized = torch.nn.functional.interpolate(
+                disp, (original_height, original_width), mode="bilinear", align_corners=False)
+
+            # EXTRACT DEPTH IMAGE
+            disp_resized_np = disp_resized.squeeze().cpu().numpy()
+            vmax = np.percentile(disp_resized_np, 95)
+            normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
+            mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
+            colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
+            im = pil.fromarray(colormapped_im)
+            gray_colormapped_im = cv2.cvtColor(colormapped_im, cv2.COLOR_RGB2GRAY)
+            inv_gray_colormapped_im = 255 - gray_colormapped_im
+
+            # MAKE HAZY IMAGE:
+            # Change degree of haze by changing 'beta' (recommended value of beta: 0.5 - 3.0)
+            # High beta -> Thick haze
+            # Low beta -> Sparse haze
+            hazy = gen_haze(clean_img, inv_gray_colormapped_im, beta=args.beta, A=args.airlight)
+
+            # SAVE FILES
+            output_name = os.path.splitext(os.path.basename(image_path))[0]
+            cv2.imwrite(f"{output_dir}/{output_name}_synt.jpg", cv2.cvtColor(hazy, cv2.COLOR_RGB2BGR))
+
+            print("   Processed {:d} of {:d} images".format(idx + 1, len(paths)))
+
+    print(f'-> Done! Find outputs in {output_dir}')
+
+if __name__ == '__main__':
+    args = parse_args()
+    test_simple(args)
diff --git a/layers.py b/layers.py
new file mode 100644
index 0000000..070cadb
--- /dev/null
+++ b/layers.py
@@ -0,0 +1,269 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def disp_to_depth(disp, min_depth, max_depth):
+    """Convert network's sigmoid output into depth prediction
+    The formula for this conversion is given in the 'additional considerations'
+    section of the paper.
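+    (The sigmoid output disp in [0, 1] is mapped linearly to scaled_disp in
+    [1/max_depth, 1/min_depth]; depth is the reciprocal of scaled_disp.)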
+ """ + min_disp = 1 / max_depth + max_disp = 1 / min_depth + scaled_disp = min_disp + (max_disp - min_disp) * disp + depth = 1 / scaled_disp + return scaled_disp, depth + + +def transformation_from_parameters(axisangle, translation, invert=False): + """Convert the network's (axisangle, translation) output into a 4x4 matrix + """ + R = rot_from_axisangle(axisangle) + t = translation.clone() + + if invert: + R = R.transpose(1, 2) + t *= -1 + + T = get_translation_matrix(t) + + if invert: + M = torch.matmul(R, T) + else: + M = torch.matmul(T, R) + + return M + + +def get_translation_matrix(translation_vector): + """Convert a translation vector into a 4x4 transformation matrix + """ + T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) + + t = translation_vector.contiguous().view(-1, 3, 1) + + T[:, 0, 0] = 1 + T[:, 1, 1] = 1 + T[:, 2, 2] = 1 + T[:, 3, 3] = 1 + T[:, :3, 3, None] = t + + return T + + +def rot_from_axisangle(vec): + """Convert an axisangle rotation into a 4x4 transformation matrix + (adapted from https://github.com/Wallacoloo/printipi) + Input 'vec' has to be Bx1x3 + """ + angle = torch.norm(vec, 2, 2, True) + axis = vec / (angle + 1e-7) + + ca = torch.cos(angle) + sa = torch.sin(angle) + C = 1 - ca + + x = axis[..., 0].unsqueeze(1) + y = axis[..., 1].unsqueeze(1) + z = axis[..., 2].unsqueeze(1) + + xs = x * sa + ys = y * sa + zs = z * sa + xC = x * C + yC = y * C + zC = z * C + xyC = x * yC + yzC = y * zC + zxC = z * xC + + rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) + + rot[:, 0, 0] = torch.squeeze(x * xC + ca) + rot[:, 0, 1] = torch.squeeze(xyC - zs) + rot[:, 0, 2] = torch.squeeze(zxC + ys) + rot[:, 1, 0] = torch.squeeze(xyC + zs) + rot[:, 1, 1] = torch.squeeze(y * yC + ca) + rot[:, 1, 2] = torch.squeeze(yzC - xs) + rot[:, 2, 0] = torch.squeeze(zxC - ys) + rot[:, 2, 1] = torch.squeeze(yzC + xs) + rot[:, 2, 2] = torch.squeeze(z * zC + ca) + rot[:, 3, 3] = 1 + + return rot + + +class ConvBlock(nn.Module): + """Layer to perform a convolution followed by ELU + """ + def __init__(self, in_channels, out_channels): + super(ConvBlock, self).__init__() + + self.conv = Conv3x3(in_channels, out_channels) + self.nonlin = nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + + +class Conv3x3(nn.Module): + """Layer to pad and convolve input + """ + def __init__(self, in_channels, out_channels, use_refl=True): + super(Conv3x3, self).__init__() + + if use_refl: + self.pad = nn.ReflectionPad2d(1) + else: + self.pad = nn.ZeroPad2d(1) + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) + + def forward(self, x): + out = self.pad(x) + out = self.conv(out) + return out + + +class BackprojectDepth(nn.Module): + """Layer to transform a depth image into a point cloud + """ + def __init__(self, batch_size, height, width): + super(BackprojectDepth, self).__init__() + + self.batch_size = batch_size + self.height = height + self.width = width + + meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') + self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) + self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), + requires_grad=False) + + self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), + requires_grad=False) + + self.pix_coords = torch.unsqueeze(torch.stack( + [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) + self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) + self.pix_coords = 
+                                       requires_grad=False)
+
+    def forward(self, depth, inv_K):
+        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
+        cam_points = depth.view(self.batch_size, 1, -1) * cam_points
+        cam_points = torch.cat([cam_points, self.ones], 1)
+
+        return cam_points
+
+
+class Project3D(nn.Module):
+    """Layer which projects 3D points into a camera with intrinsics K and at position T
+    """
+    def __init__(self, batch_size, height, width, eps=1e-7):
+        super(Project3D, self).__init__()
+
+        self.batch_size = batch_size
+        self.height = height
+        self.width = width
+        self.eps = eps
+
+    def forward(self, points, K, T):
+        P = torch.matmul(K, T)[:, :3, :]
+
+        cam_points = torch.matmul(P, points)
+
+        pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
+        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
+        pix_coords = pix_coords.permute(0, 2, 3, 1)
+        pix_coords[..., 0] /= self.width - 1
+        pix_coords[..., 1] /= self.height - 1
+        pix_coords = (pix_coords - 0.5) * 2
+        return pix_coords
+
+
+def upsample(x):
+    """Upsample input tensor by a factor of 2
+    """
+    return F.interpolate(x, scale_factor=2, mode="nearest")
+
+
+def get_smooth_loss(disp, img):
+    """Computes the smoothness loss for a disparity image
+    The color image is used for edge-aware smoothness
+    """
+    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
+    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
+
+    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
+    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
+
+    grad_disp_x *= torch.exp(-grad_img_x)
+    grad_disp_y *= torch.exp(-grad_img_y)
+
+    return grad_disp_x.mean() + grad_disp_y.mean()
+
+
+class SSIM(nn.Module):
+    """Layer to compute the SSIM loss between a pair of images
+    """
+    def __init__(self):
+        super(SSIM, self).__init__()
+        self.mu_x_pool = nn.AvgPool2d(3, 1)
+        self.mu_y_pool = nn.AvgPool2d(3, 1)
+        self.sig_x_pool = nn.AvgPool2d(3, 1)
+        self.sig_y_pool = nn.AvgPool2d(3, 1)
+        self.sig_xy_pool = nn.AvgPool2d(3, 1)
+
+        self.refl = nn.ReflectionPad2d(1)
+
+        self.C1 = 0.01 ** 2
+        self.C2 = 0.03 ** 2
+
+    def forward(self, x, y):
+        x = self.refl(x)
+        y = self.refl(y)
+
+        mu_x = self.mu_x_pool(x)
+        mu_y = self.mu_y_pool(y)
+
+        sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
+        sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
+        sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
+
+        SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
+        SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
+
+        return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
+
+
+def compute_depth_errors(gt, pred):
+    """Computation of error metrics between predicted and ground truth depths
+    """
+    thresh = torch.max((gt / pred), (pred / gt))
+    a1 = (thresh < 1.25).float().mean()
+    a2 = (thresh < 1.25 ** 2).float().mean()
+    a3 = (thresh < 1.25 ** 3).float().mean()
+
+    rmse = (gt - pred) ** 2
+    rmse = torch.sqrt(rmse.mean())
+
+    rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
+    rmse_log = torch.sqrt(rmse_log.mean())
+
+    abs_rel = torch.mean(torch.abs(gt - pred) / gt)
+
+    sq_rel = torch.mean((gt - pred) ** 2 / gt)
+
+    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
diff --git a/models/mono+stereo_640x192/model_weight_files.txt b/models/mono+stereo_640x192/model_weight_files.txt
new file mode 100644
index 0000000..527643e
--- /dev/null
+++ b/models/mono+stereo_640x192/model_weight_files.txt
@@ -0,0 +1 @@
+Put model weight files in this folder!
\ No newline at end of file
diff --git a/networks/__init__.py b/networks/__init__.py
new file mode 100644
index 0000000..2386870
--- /dev/null
+++ b/networks/__init__.py
@@ -0,0 +1,4 @@
+from .resnet_encoder import ResnetEncoder
+from .depth_decoder import DepthDecoder
+from .pose_decoder import PoseDecoder
+from .pose_cnn import PoseCNN
diff --git a/networks/depth_decoder.py b/networks/depth_decoder.py
new file mode 100644
index 0000000..498ec38
--- /dev/null
+++ b/networks/depth_decoder.py
@@ -0,0 +1,65 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from collections import OrderedDict
+from layers import *
+
+
+class DepthDecoder(nn.Module):
+    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
+        super(DepthDecoder, self).__init__()
+
+        self.num_output_channels = num_output_channels
+        self.use_skips = use_skips
+        self.upsample_mode = 'nearest'
+        self.scales = scales
+
+        self.num_ch_enc = num_ch_enc
+        self.num_ch_dec = np.array([16, 32, 64, 128, 256])
+
+        # decoder
+        self.convs = OrderedDict()
+        for i in range(4, -1, -1):
+            # upconv_0
+            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
+            num_ch_out = self.num_ch_dec[i]
+            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
+
+            # upconv_1
+            num_ch_in = self.num_ch_dec[i]
+            if self.use_skips and i > 0:
+                num_ch_in += self.num_ch_enc[i - 1]
+            num_ch_out = self.num_ch_dec[i]
+            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
+
+        for s in self.scales:
+            self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
+
+        self.decoder = nn.ModuleList(list(self.convs.values()))
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input_features):
+        self.outputs = {}
+
+        # decoder
+        x = input_features[-1]
+        for i in range(4, -1, -1):
+            x = self.convs[("upconv", i, 0)](x)
+            x = [upsample(x)]
+            if self.use_skips and i > 0:
+                x += [input_features[i - 1]]
+            x = torch.cat(x, 1)
+            x = self.convs[("upconv", i, 1)](x)
+            if i in self.scales:
+                self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
+
+        return self.outputs
diff --git a/networks/pose_cnn.py b/networks/pose_cnn.py
new file mode 100644
index 0000000..16baec7
--- /dev/null
+++ b/networks/pose_cnn.py
@@ -0,0 +1,50 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
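+#
+# PoseCNN regresses a 6-DoF relative pose (3 axis-angle + 3 translation values)
+# for each source frame from a channel-wise stack of input images.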
+
+from __future__ import absolute_import, division, print_function
+
+import torch
+import torch.nn as nn
+
+
+class PoseCNN(nn.Module):
+    def __init__(self, num_input_frames):
+        super(PoseCNN, self).__init__()
+
+        self.num_input_frames = num_input_frames
+
+        self.convs = {}
+        self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
+        self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
+        self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
+        self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
+        self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
+        self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
+        self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)
+
+        self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)
+
+        self.num_convs = len(self.convs)
+
+        self.relu = nn.ReLU(True)
+
+        self.net = nn.ModuleList(list(self.convs.values()))
+
+    def forward(self, out):
+
+        for i in range(self.num_convs):
+            out = self.convs[i](out)
+            out = self.relu(out)
+
+        out = self.pose_conv(out)
+        out = out.mean(3).mean(2)
+
+        out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)
+
+        axisangle = out[..., :3]
+        translation = out[..., 3:]
+
+        return axisangle, translation
diff --git a/networks/pose_decoder.py b/networks/pose_decoder.py
new file mode 100644
index 0000000..4b03b60
--- /dev/null
+++ b/networks/pose_decoder.py
@@ -0,0 +1,54 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
+
+from __future__ import absolute_import, division, print_function
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+class PoseDecoder(nn.Module):
+    def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
+        super(PoseDecoder, self).__init__()
+
+        self.num_ch_enc = num_ch_enc
+        self.num_input_features = num_input_features
+
+        if num_frames_to_predict_for is None:
+            num_frames_to_predict_for = num_input_features - 1
+        self.num_frames_to_predict_for = num_frames_to_predict_for
+
+        self.convs = OrderedDict()
+        self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
+        self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
+        self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
+        self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
+
+        self.relu = nn.ReLU()
+
+        self.net = nn.ModuleList(list(self.convs.values()))
+
+    def forward(self, input_features):
+        last_features = [f[-1] for f in input_features]
+
+        cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
+        cat_features = torch.cat(cat_features, 1)
+
+        out = cat_features
+        for i in range(3):
+            out = self.convs[("pose", i)](out)
+            if i != 2:
+                out = self.relu(out)
+
+        out = out.mean(3).mean(2)
+
+        out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
+
+        axisangle = out[..., :3]
+        translation = out[..., 3:]
+
+        return axisangle, translation
diff --git a/networks/resnet_encoder.py b/networks/resnet_encoder.py
new file mode 100644
index 0000000..9c94418
--- /dev/null
+++ b/networks/resnet_encoder.py
@@ -0,0 +1,98 @@
+# Copyright Niantic 2019. Patent Pending. All rights reserved.
+#
+# This software is licensed under the terms of the Monodepth2 licence
+# which allows for non-commercial use only, the full terms of which are made
+# available in the LICENSE file.
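+#
+# ResnetEncoder wraps a torchvision ResNet (optionally taking several stacked
+# input frames) and returns the feature maps from its five stages for the decoders.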
+
+from __future__ import absolute_import, division, print_function
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torchvision.models as models
+import torch.utils.model_zoo as model_zoo
+
+
+class ResNetMultiImageInput(models.ResNet):
+    """Constructs a resnet model with varying number of input images.
+    Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+    """
+    def __init__(self, block, layers, num_classes=1000, num_input_images=1):
+        super(ResNetMultiImageInput, self).__init__(block, layers)
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(
+            num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+
+def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
+    """Constructs a ResNet model.
+    Args:
+        num_layers (int): Number of resnet layers. Must be 18 or 50
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        num_input_images (int): Number of frames stacked as input
+    """
+    assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
+    blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
+    block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
+    model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
+
+    if pretrained:
+        loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
+        loaded['conv1.weight'] = torch.cat(
+            [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
+        model.load_state_dict(loaded)
+    return model
+
+
+class ResnetEncoder(nn.Module):
+    """Pytorch module for a resnet encoder
+    """
+    def __init__(self, num_layers, pretrained, num_input_images=1):
+        super(ResnetEncoder, self).__init__()
+
+        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
+
+        resnets = {18: models.resnet18,
+                   34: models.resnet34,
+                   50: models.resnet50,
+                   101: models.resnet101,
+                   152: models.resnet152}
+
+        if num_layers not in resnets:
+            raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
+
+        if num_input_images > 1:
+            self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
+        else:
+            self.encoder = resnets[num_layers](pretrained)
+
+        if num_layers > 34:
+            self.num_ch_enc[1:] *= 4
+
+    def forward(self, input_image):
+        self.features = []
+        x = (input_image - 0.45) / 0.225
+        x = self.encoder.conv1(x)
+        x = self.encoder.bn1(x)
+        self.features.append(self.encoder.relu(x))
+        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
+        self.features.append(self.encoder.layer2(self.features[-1]))
+        self.features.append(self.encoder.layer3(self.features[-1]))
+        self.features.append(self.encoder.layer4(self.features[-1]))
+
+        return self.features