Skip to content

Commit

Permalink
Merge branch 'main' into leonwu0108/main
Browse files Browse the repository at this point in the history
hugoycj committed Jul 23, 2024
2 parents 6c298b0 + 15b3aa8 commit ee41a99
Showing 5 changed files with 44 additions and 14 deletions.
10 changes: 9 additions & 1 deletion scene/cameras.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@

class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid, principal_point_ndc,
image_name, uid, principal_point_ndc, normal=None,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
):
super(Camera, self).__init__()
@@ -37,6 +37,14 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
self.data_device = torch.device("cuda")

self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
if normal is not None:
self.normal = normal.to(self.data_device)
normal_norm = torch.norm(self.normal, dim=0, keepdim=True)
self.normal_mask = ~((normal_norm > 1.1) | (normal_norm < 0.9))
self.normal = self.normal / normal_norm
else:
self.normal = None
self.normal_mask = None
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]

2 changes: 1 addition & 1 deletion submodules/diff-surfel-rasterization
26 changes: 24 additions & 2 deletions utils/camera_utils.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,9 @@
import numpy as np
from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal

from PIL import Image
import os
import torch.nn.functional as F
WARNED = False

def loadCam(args, id, cam_info, resolution_scale):
@@ -48,9 +50,29 @@ def loadCam(args, id, cam_info, resolution_scale):
loaded_mask = None
gt_image = resized_image_rgb

if args.w_normal_prior:
import torch
# normal_path = cam_info.image_path.replace('images_4', args.w_normal_prior)
normal_path = os.path.join(os.path.dirname(os.path.dirname(cam_info.image_path)), args.w_normal_prior, os.path.basename(cam_info.image_path))
if os.path.exists(normal_path[:-4]+ '.npy'):
_normal = torch.tensor(np.load(normal_path[:-4]+ '.npy'))
_normal = - (_normal * 2 - 1)
resized_normal = F.interpolate(_normal.unsqueeze(0), size=resolution[::-1], mode='bicubic')
_normal = resized_normal.squeeze(0)
else:
_normal = Image.open(normal_path[:-4]+ '.png')
resized_normal = PILtoTorch(_normal, resolution)
resized_normal = resized_normal[:3]
_normal = - (resized_normal * 2 - 1)
# normalize normal
_normal = _normal.permute(1, 2, 0) @ (torch.tensor(np.linalg.inv(cam_info.R)).float())
_normal = _normal.permute(2, 0, 1)
else:
_normal = None

return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image=gt_image, normal=_normal, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id,
principal_point_ndc=cam_info.principal_point_ndc,
data_device=args.data_device)
6 changes: 3 additions & 3 deletions utils/mesh_utils.py
Original file line number Diff line number Diff line change
@@ -116,8 +116,8 @@ def reconstruction(self, viewpoint_stack):
# self.normals.append(normal.cpu())
# self.depth_normals.append(depth_normal.cpu())

self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
self.depthmaps = torch.stack(self.depthmaps, dim=0)
# self.rgbmaps = torch.stack(self.rgbmaps, dim=0)
# self.depthmaps = torch.stack(self.depthmaps, dim=0)
# self.alphamaps = torch.stack(self.alphamaps, dim=0)
# self.depth_normals = torch.stack(self.depth_normals, dim=0)
self.estimate_bounding_sphere()
@@ -292,4 +292,4 @@ def export_image(self, path):
save_img_u8(self.rgbmaps[idx].permute(1,2,0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff"))
# save_img_u8(self.normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + ".png"))
# save_img_u8(self.depth_normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png"))
# save_img_u8(self.depth_normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + ".png"))
14 changes: 7 additions & 7 deletions utils/point_utils.py
Original file line number Diff line number Diff line change
@@ -9,13 +9,13 @@
def depths_to_points(view, depthmap):
c2w = (view.world_view_transform.T).inverse()
W, H = view.image_width, view.image_height
fx = W / (2 * math.tan(view.FoVx / 2.))
fy = H / (2 * math.tan(view.FoVy / 2.))
intrins = torch.tensor(
[[fx, 0., W/2.],
[0., fy, H/2.],
[0., 0., 1.0]]
).float().cuda()
ndc2pix = torch.tensor([
[W / 2, 0, 0, (W) / 2],
[0, H / 2, 0, (H) / 2],
[0, 0, 0, 1]]).float().cuda().T
projection_matrix = c2w.T @ view.full_proj_transform
intrins = (projection_matrix @ ndc2pix)[:3,:3].T

grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy')
points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3)
rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T

0 comments on commit ee41a99

Please sign in to comment.