
Unify data loaders and eval scripts; allow arbitrary image size #57

Merged
merged 13 commits on Dec 2, 2021
Fix bug in differentiable warping and store new checkpoints
Antonios Matakos committed Oct 30, 2021
commit 9eb104ecdf44c81171ea9982a244d0e95936f2a8
Binary file removed checkpoints/model_000007.ckpt
Binary file added checkpoints/module_000007.pt
Binary file added checkpoints/params_000007.ckpt
34 changes: 14 additions & 20 deletions models/module.py
@@ -133,16 +133,16 @@ def differentiable_warping(
     """Differentiable homography-based warping, implemented in Pytorch.

     Args:
-        src_fea: [B, C, H, W] source features, for each source view in batch
+        src_fea: [B, C, Hin, Win] source features, for each source view in batch
         src_proj: [B, 4, 4] source camera projection matrix, for each source view in batch
         ref_proj: [B, 4, 4] reference camera projection matrix, for each ref view in batch
-        depth_samples: [B, Ndepth, H, W] virtual depth layers
+        depth_samples: [B, Ndepth, Hout, Wout] virtual depth layers
     Returns:
-        warped_src_fea: [B, C, Ndepth, H, W] features on depths after perspective transformation
+        warped_src_fea: [B, C, Ndepth, Hout, Wout] features on depths after perspective transformation
     """

-    batch, channels, height, width = src_fea.shape
-    num_depth = depth_samples.shape[1]
+    batch, num_depth, height, width = depth_samples.shape
+    channels = src_fea.shape[1]

     with torch.no_grad():
         proj = torch.matmul(src_proj, torch.inverse(ref_proj))
@@ -155,13 +155,10 @@ def differentiable_warping(
                 torch.arange(0, width, dtype=torch.float32, device=src_fea.device),
             ]
         )
-        y, x = y.contiguous(), x.contiguous()
-        y, x = y.view(height * width), x.view(height * width)
-        xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
-        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
-        rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
+        y, x = y.contiguous().view(height * width), x.contiguous().view(height * width)
+        xyz = torch.unsqueeze(torch.stack((x, y, torch.ones_like(x))), 0).repeat(batch, 1, 1) # [B, 3, H*W]

-        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(
+        rot_depth_xyz = torch.matmul(rot, xyz).unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(
             batch, 1, num_depth, height * width
         ) # [B, 3, Ndepth, H*W]
         proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
@@ -170,21 +167,18 @@ def differentiable_warping(
         proj_xyz[:, 0:1][negative_depth_mask] = float(width)
         proj_xyz[:, 1:2][negative_depth_mask] = float(height)
         proj_xyz[:, 2:3][negative_depth_mask] = 1.0
-        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
-        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 # [B, Ndepth, H*W]
-        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
-        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
-        grid = proj_xy
+        grid = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
+        proj_x_normalized = grid[:, 0, :, :] / ((width - 1) / 2) - 1 # [B, Ndepth, H*W]
+        proj_y_normalized = grid[:, 1, :, :] / ((height - 1) / 2) - 1
+        grid = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]

-    warped_src_fea = F.grid_sample(
+    return F.grid_sample(
         src_fea,
         grid.view(batch, num_depth * height, width, 2),
         mode="bilinear",
         padding_mode="zeros",
         align_corners=True,
-    )
-
-    return warped_src_fea.view(batch, channels, num_depth, height, width)
+    ).view(batch, channels, num_depth, height, width)


 def depth_regression(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor:
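For readers checking the fix, here is a minimal usage sketch (not part of this PR): it assumes differentiable_warping is imported from models/module.py as shown in the diff, and uses placeholder identity projection matrices and random tensors. The point it illustrates is that the warp output resolution now follows depth_samples (Hout, Wout), so src_fea may have a different spatial size (Hin, Win).

```python
import torch

from models.module import differentiable_warping  # module path as in this diff

# Illustrative sizes only: source features and depth samples at different resolutions.
B, C, Ndepth = 2, 8, 4
Hin, Win = 64, 80    # source feature resolution
Hout, Wout = 32, 40  # depth-sample (output) resolution

src_fea = torch.rand(B, C, Hin, Win)
src_proj = torch.eye(4).repeat(B, 1, 1)  # placeholder camera projection matrices
ref_proj = torch.eye(4).repeat(B, 1, 1)
depth_samples = torch.rand(B, Ndepth, Hout, Wout) + 1.0  # keep virtual depths positive

warped = differentiable_warping(src_fea, src_proj, ref_proj, depth_samples)
assert warped.shape == (B, C, Ndepth, Hout, Wout)
```

With the previous shape derivation, height and width were taken from src_fea, so mismatched resolutions like the ones above could not be handled: depth_samples was reshaped using the source-feature dimensions.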