diff --git a/calib/nmail3.txt b/calib/nmail3.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/droid_slam/droid.py b/droid_slam/droid.py
index fe99c5ca..dbb30934 100644
--- a/droid_slam/droid.py
+++ b/droid_slam/droid.py
@@ -8,6 +8,7 @@
 from droid_frontend import DroidFrontend
 from droid_backend import DroidBackend
 from trajectory_filler import PoseTrajectoryFiller
+from cdsmvsnet import CDSMVSNet
 
 from collections import OrderedDict
 from torch.multiprocessing import Process
@@ -20,6 +21,15 @@ def __init__(self, args):
         self.args = args
         self.disable_vis = args.disable_vis
 
+        # dense depth prediction with CDS-MVSNet
+        self.mvsnet = CDSMVSNet(refine=True, ndepths=(64, 32, 8), depth_interals_ratio=(4, 2, 1))
+        mvsnet_ckpt = torch.load(args.mvsnet_ckpt)
+        state_dict = OrderedDict([
+            (k.replace("module.", ""), v) for (k, v) in mvsnet_ckpt["state_dict"].items()
+        ])
+        self.mvsnet.load_state_dict(state_dict)
+        self.mvsnet.to("cuda:0").eval()
+
         # store images, depth, poses, intrinsics (shared between processes)
         self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo)
 
@@ -27,7 +37,7 @@ def __init__(self, args):
         self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh)
 
         # frontend process
-        self.frontend = DroidFrontend(self.net, self.video, self.args)
+        self.frontend = DroidFrontend(self.net, self.video, self.mvsnet, self.args)
 
         # backend process
         self.backend = DroidBackend(self.net, self.video, self.args)
diff --git a/droid_slam/droid_frontend.py b/droid_slam/droid_frontend.py
index b69bc943..25393811 100644
--- a/droid_slam/droid_frontend.py
+++ b/droid_slam/droid_frontend.py
@@ -4,14 +4,15 @@
 
 from lietorch import SE3
 from factor_graph import FactorGraph
+from cdsmvsnet import CDSMVSNet
 
 
 class DroidFrontend:
-    def __init__(self, net, video, args):
+    def __init__(self, net, video, mvsnet, args):
         self.video = video
         self.update_op = net.update
         self.graph = FactorGraph(video, net.update, max_factors=48)
-        self.mvsnet = CDS
+        self.mvsnet = mvsnet
 
         # local optimization window
         self.t0 = 0
@@ -66,6 +67,17 @@ def __update(self):
             for itr in range(self.iters2):
                 self.graph.update(None, None, use_inactive=True)
 
+        # refine depths of the reference keyframe with the MVS network
+        if self.mvsnet is not None:
+            ref_id, src_ids = self.t1 - 3, [self.t1-5, self.t1-4, self.t1-2, self.t1-1]
+            img_ids = [ref_id] + src_ids
+            intrinsics = self.video.intrinsics[img_ids]
+            poses = SE3(self.video.poses[img_ids])
+            ref_disp = self.video.disps[ref_id]
+            val_depths = ref_disp[(ref_disp > 0.001) & (ref_disp < 1000)]
+            min_d, max_d = val_depths.min(), val_depths.max()
+            d_interval = (max_d - min_d) / 192
+            depth_values = torch.arange(min_d, max_d, d_interval).unsqueeze(0)
         # set pose for next itration
         self.video.poses[self.t1] = self.video.poses[self.t1-1]
         self.video.disps[self.t1] = self.video.disps[self.t1-1].mean()
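
The new block in __update prepares intrinsics, poses, and depth_values but stops short of calling the network. Below is a minimal sketch of how those tensors could be assembled into MVSNet-style projection matrices and passed on, assuming DROID-SLAM's [fx, fy, cx, cy] intrinsics layout, world-to-camera poses, and a CasMVSNet-style mvsnet(imgs, proj_matrices, depth_values) interface. The helper name build_proj_matrices, the call signature, and the output key are assumptions and should be checked against the cdsmvsnet code; rescaling between image and disparity resolution is also left out.

import torch

def build_proj_matrices(poses, intrinsics):
    # poses: lietorch SE3 for the (ref + src) views, world-to-camera in DROID's convention (assumed)
    # intrinsics: (V, 4) tensor of [fx, fy, cx, cy] per view
    # returns: (V, 2, 4, 4) tensor, [:, 0] = extrinsic, [:, 1] = intrinsic (CasMVSNet layout, assumed)
    V = intrinsics.shape[0]
    proj = torch.zeros(V, 2, 4, 4, device=intrinsics.device)
    proj[:, 0] = poses.matrix()                      # 4x4 rigid-body transforms
    K = torch.eye(4, device=intrinsics.device).repeat(V, 1, 1)
    K[:, 0, 0], K[:, 1, 1] = intrinsics[:, 0], intrinsics[:, 1]   # fx, fy
    K[:, 0, 2], K[:, 1, 2] = intrinsics[:, 2], intrinsics[:, 3]   # cx, cy
    proj[:, 1] = K
    return proj

# hypothetical call site inside __update, after depth_values is built:
#   imgs = self.video.images[img_ids].unsqueeze(0).float() / 255.0    # (1, V, 3, H, W)
#   proj_matrices = build_proj_matrices(poses, intrinsics).unsqueeze(0)
#   outputs = self.mvsnet(imgs, proj_matrices, depth_values)
#   self.video.disps[ref_id] = 1.0 / outputs["depth"]   # if the network returns metric depth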