Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
TruongKhang committed Dec 8, 2021
1 parent d27c374 commit 633f258
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
Empty file added calib/nmail3.txt
Empty file.
12 changes: 11 additions & 1 deletion droid_slam/droid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from droid_frontend import DroidFrontend
from droid_backend import DroidBackend
from trajectory_filler import PoseTrajectoryFiller
from cdsmvsnet import CDSMVSNet

from collections import OrderedDict
from torch.multiprocessing import Process
Expand All @@ -20,14 +21,23 @@ def __init__(self, args):
self.args = args
self.disable_vis = args.disable_vis

# dense depth prediction
self.mvsnet = CDSMVSNet(refine=True, ndepths=(64, 32, 8), depth_interals_ratio=(4, 2, 1))
mvsnet_ckpt = torch.load(args.mvsnet_ckpt)
state_dict = OrderedDict([
(k.replace("module.", ""), v) for (k, v) in torch.load(mvsnet_ckpt["state_dict"]).items()
])
self.mvsnet.load_state_dict(state_dict)
self.mvsnet.to("cuda:0").eval()

# store images, depth, poses, intrinsics (shared between processes)
self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo)

# filter incoming frames so that there is enough motion
self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh)

# frontend process
self.frontend = DroidFrontend(self.net, self.video, self.args)
self.frontend = DroidFrontend(self.net, self.video, self.mvsnet, self.args)

# backend process
self.backend = DroidBackend(self.net, self.video, self.args)
Expand Down
16 changes: 14 additions & 2 deletions droid_slam/droid_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

from lietorch import SE3
from factor_graph import FactorGraph
from cdsmvsnet import CDSMVSNet


class DroidFrontend:
def __init__(self, net, video, args):
def __init__(self, net, video, mvsnet, args):
self.video = video
self.update_op = net.update
self.graph = FactorGraph(video, net.update, max_factors=48)
self.mvsnet = CDS
self.mvsnet = mvsnet

# local optimization window
self.t0 = 0
Expand Down Expand Up @@ -66,6 +67,17 @@ def __update(self):
for itr in range(self.iters2):
self.graph.update(None, None, use_inactive=True)

# refine depths
if self.mvsnet is not None:
ref_id, src_ids = self.t1 - 3, [self.t1-5, self.t1-4, self.t1-2, self.t1-1]
img_ids = [ref_id] + src_ids
intrinsics = self.video.intrinsics[img_ids]
poses = SE3(self.video.poses[img_ids])
ref_disp = self.video.disp[ref_id]
val_depths = ref_disp[(ref_disp > 0.001) & (ref_disp < 1000)]
min_d, max_d = val_depths.min(), val_depths.max()
d_interval = (max_d - min_d) / 192
depth_values = torch.arange(min_d, max_d, d_interval).unsqueeze(0)
# set pose for next itration
self.video.poses[self.t1] = self.video.poses[self.t1-1]
self.video.disps[self.t1] = self.video.disps[self.t1-1].mean()
Expand Down

0 comments on commit 633f258

Please sign in to comment.