Skip to content

Commit

Permalink
Gaussian splatting support for Aria (#2785)
Browse files Browse the repository at this point in the history
* Gaussian splatting support for Aria

* Respect masks in splatting loss function
  • Loading branch information
brentyi authored Jan 19, 2024
1 parent c27b5cb commit 6389cac
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 126 deletions.
2 changes: 1 addition & 1 deletion nerfstudio/cameras/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
function so this solves an optimization problem using Newton's method to get
the inverse.
Inputs:
uv: BxNx3 tensor of 2D pixels to be projected
uv: BxNx2 tensor of 2D pixels to be unprojected
params: Bx16 tensor of Fisheye624 parameters formatted like this:
[f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
or Bx15 tensor of Fisheye624 parameters formatted like this:
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/cameras/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def _compute_rays_for_vr180(

assert distortion_params is not None
masked_coords = pcoord_stack[coord_mask, :]
# The fisheye unprojection does not rely on planar/pinhold unprojection, thus the method needs
# The fisheye unprojection does not rely on planar/pinhole unprojection, thus the method needs
# to access the focal length and principle points directly.
camera_params = torch.cat(
[
Expand Down
276 changes: 166 additions & 110 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.nn import Parameter
from tqdm import tqdm

from nerfstudio.cameras.camera_utils import fisheye624_project, fisheye624_unproject_helper
from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.datamanagers.base_datamanager import DataManager, DataManagerConfig, TDataset
Expand Down Expand Up @@ -135,70 +136,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
# update the width, height
self.train_dataset.cameras.width[i] = w
self.train_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_train.append(data)

self.train_dataset.cameras.fx[i] = float(K[0, 0])
self.train_dataset.cameras.fy[i] = float(K[1, 1])
self.train_dataset.cameras.cx[i] = float(K[0, 2])
self.train_dataset.cameras.cy[i] = float(K[1, 2])
self.train_dataset.cameras.width[i] = image.shape[1]
self.train_dataset.cameras.height[i] = image.shape[0]

CONSOLE.log("Caching / undistorting eval images")
for i in tqdm(range(len(self.eval_dataset)), leave=False):
Expand All @@ -210,68 +161,20 @@ def cache_images(self, cache_images_option):
continue
distortion_params = camera.distortion_params.numpy()
image = data["image"].numpy()
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
# update the width, height
self.eval_dataset.cameras.width[i] = w
self.eval_dataset.cameras.height[i] = h
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
data["mask"] = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
data["mask"] = torch.from_numpy(mask).bool()
K = newK
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

K, image, mask = _undistort_image(camera, distortion_params, data, image, K)
data["image"] = torch.from_numpy(image)
if mask is not None:
data["mask"] = mask

cached_eval.append(data)

self.eval_dataset.cameras.fx[i] = float(K[0, 0])
self.eval_dataset.cameras.fy[i] = float(K[1, 1])
self.eval_dataset.cameras.cx[i] = float(K[0, 2])
self.eval_dataset.cameras.cy[i] = float(K[1, 2])
self.eval_dataset.cameras.width[i] = image.shape[1]
self.eval_dataset.cameras.height[i] = image.shape[0]

if cache_images_option == "gpu":
for cache in cached_train:
Expand Down Expand Up @@ -416,3 +319,156 @@ def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]:
assert len(self.eval_dataset.cameras.shape) == 1, "Assumes single batch dimension"
camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device)
return camera, data


def _undistort_image(
camera: Cameras, distortion_params: np.ndarray, data: dict, image: np.ndarray, K: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Optional[torch.Tensor]]:
mask = None
if camera.camera_type.item() == CameraType.PERSPECTIVE.value:
distortion_params = np.array(
[
distortion_params[0],
distortion_params[1],
distortion_params[4],
distortion_params[5],
distortion_params[2],
distortion_params[3],
0,
0,
]
)
if np.any(distortion_params):
newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
image = cv2.undistort(image, K, distortion_params, None, newK) # type: ignore
else:
newK = K
roi = 0, 0, image.shape[1], image.shape[0]
# crop the image and update the intrinsics accordingly
x, y, w, h = roi
image = image[y : y + h, x : x + w]
if "depth_image" in data:
data["depth_image"] = data["depth_image"][y : y + h, x : x + w]
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
if np.any(distortion_params):
mask = cv2.undistort(mask, K, distortion_params, None, newK) # type: ignore
mask = mask[y : y + h, x : x + w]
mask = torch.from_numpy(mask).bool()
K = newK

elif camera.camera_type.item() == CameraType.FISHEYE.value:
distortion_params = np.array(
[distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
)
newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
)
map1, map2 = cv2.fisheye.initUndistortRectifyMap(
K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
)
# and then remap:
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
if "mask" in data:
mask = data["mask"].numpy()
mask = mask.astype(np.uint8) * 255
mask = cv2.fisheye.undistortImage(mask, K, distortion_params, None, newK)
mask = torch.from_numpy(mask).bool()
K = newK
elif camera.camera_type.item() == CameraType.FISHEYE624.value:
fisheye624_params = torch.cat(
[camera.fx, camera.fy, camera.cx, camera.cy, torch.from_numpy(distortion_params)], dim=0
)
assert fisheye624_params.shape == (16,)
assert (
"mask" not in data
and camera.metadata is not None
and "fisheye_crop_radius" in camera.metadata
and isinstance(camera.metadata["fisheye_crop_radius"], float)
)
fisheye_crop_radius = camera.metadata["fisheye_crop_radius"]

# Approximate the FOV of the unmasked region of the camera.
upper, lower, left, right = fisheye624_unproject_helper(
torch.tensor(
[
[camera.cx, camera.cy - fisheye_crop_radius],
[camera.cx, camera.cy + fisheye_crop_radius],
[camera.cx - fisheye_crop_radius, camera.cy],
[camera.cx + fisheye_crop_radius, camera.cy],
],
dtype=torch.float32,
)[None],
params=fisheye624_params[None],
).squeeze(dim=0)
fov_radians = torch.max(
torch.acos(torch.sum(upper * lower / torch.linalg.norm(upper) / torch.linalg.norm(lower))),
torch.acos(torch.sum(left * right / torch.linalg.norm(left) / torch.linalg.norm(right))),
)

# Heuristics to determine parameters of an undistorted image.
undist_h = int(fisheye_crop_radius * 2)
undist_w = int(fisheye_crop_radius * 2)
undistort_focal = undist_h / (2 * torch.tan(fov_radians / 2.0))
undist_K = torch.eye(3)
undist_K[0, 0] = undistort_focal # fx
undist_K[1, 1] = undistort_focal # fy
undist_K[0, 2] = (undist_w - 1) / 2.0 # cx; for a 1x1 image, center should be at (0, 0).
undist_K[1, 2] = (undist_h - 1) / 2.0 # cy

# Undistorted 2D coordinates -> rays -> reproject to distorted UV coordinates.
undist_uv_homog = torch.stack(
[
*torch.meshgrid(
torch.arange(undist_w, dtype=torch.float32),
torch.arange(undist_h, dtype=torch.float32),
),
torch.ones((undist_w, undist_h), dtype=torch.float32),
],
dim=-1,
)
assert undist_uv_homog.shape == (undist_w, undist_h, 3)
dist_uv = (
fisheye624_project(
xyz=(
torch.einsum(
"ij,bj->bi",
torch.linalg.inv(undist_K),
undist_uv_homog.reshape((undist_w * undist_h, 3)),
)[None]
),
params=fisheye624_params[None, :],
)
.reshape((undist_w, undist_h, 2))
.numpy()
)
map1 = dist_uv[..., 1]
map2 = dist_uv[..., 0]

# Use correspondence to undistort image.
image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)

# Compute undistorted mask as well.
dist_h = camera.height.item()
dist_w = camera.width.item()
mask = np.mgrid[:dist_h, :dist_w]
mask[0, ...] -= dist_h // 2
mask[1, ...] -= dist_w // 2
mask = np.linalg.norm(mask, axis=0) < fisheye_crop_radius
mask = torch.from_numpy(
cv2.remap(
mask.astype(np.uint8) * 255,
map1,
map2,
interpolation=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0,
)
/ 255.0
).bool()
K = undist_K.numpy()
else:
raise NotImplementedError("Only perspective and fisheye cameras are supported")

return K, image, mask
Loading

0 comments on commit 6389cac

Please sign in to comment.