Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .cuda._torch_impl_2dgs import accumulate_2dgs
from .cuda._wrapper import (
CameraModel,
ExternalDistortionModelMeta,
RollingShutterType,
fully_fused_projection,
fully_fused_projection_2dgs,
Expand Down Expand Up @@ -64,6 +65,7 @@
"MCMCStrategy",
"Strategy",
"CameraModel",
"ExternalDistortionModelMeta",
"RasterizeMode",
"RenderMode",
"rasterization",
Expand Down
65 changes: 65 additions & 0 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from gsplat._helper import assert_shape

CameraModel = Literal["pinhole", "ortho", "fisheye", "ftheta"]
ExternalDistortionModelMeta = Literal["bivariate-windshield"]


def _make_lazy_cuda_func(name: str) -> Callable:
Expand Down Expand Up @@ -116,6 +117,49 @@ def to_cpp(self) -> Any:
return p


class ExternalDistortionReferencePolynomial(Enum):
FORWARD = 1
BACKWARD = 2

def to_cpp(self) -> Any:
return _make_lazy_cuda_obj(f"ExternalDistortionReferencePolynomial.{self.name}")


@dataclass
class BivariateWindshieldModelParameters:
MAX_ORDER = 5
MAX_COEFFS = 21

reference_poly: ExternalDistortionReferencePolynomial
horizontal_poly: Tensor # [..., (N + 1) * (N + 2) / 2]
vertical_poly: Tensor # [..., (N + 1) * (N + 2) / 2]
horizontal_poly_inverse: Tensor # [..., (N + 1) * (N + 2) / 2]
vertical_poly_inverse: Tensor # [..., (N + 1) * (N + 2) / 2]

def __init__(
self,
reference_poly: ExternalDistortionReferencePolynomial,
horizontal_poly: Tensor,
vertical_poly: Tensor,
horizontal_poly_inverse: Tensor,
vertical_poly_inverse: Tensor,
):
self.reference_poly = reference_poly
self.horizontal_poly = horizontal_poly
self.vertical_poly = vertical_poly
self.horizontal_poly_inverse = horizontal_poly_inverse
self.vertical_poly_inverse = vertical_poly_inverse

def to_cpp(self) -> Any:
p = _make_lazy_cuda_obj("BivariateWindshieldModelParameters")()
p.reference_poly = self.reference_poly.to_cpp()
p.horizontal_poly = self.horizontal_poly.contiguous()
p.vertical_poly = self.vertical_poly.contiguous()
p.horizontal_poly_inverse = self.horizontal_poly_inverse.contiguous()
p.vertical_poly_inverse = self.vertical_poly_inverse.contiguous()
return p


class FThetaPolynomialType(Enum):
PIXELDIST_TO_ANGLE = 0
ANGLE_TO_PIXELDIST = 1
Expand Down Expand Up @@ -752,6 +796,7 @@ def rasterize_to_pixels_eval3d(
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
# rolling shutter
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]
Expand Down Expand Up @@ -793,6 +838,7 @@ def rasterize_to_pixels_eval3d(
tangential_coeffs=tangential_coeffs,
thin_prism_coeffs=thin_prism_coeffs,
ftheta_coeffs=ftheta_coeffs,
external_distortion_coeffs=external_distortion_coeffs,
rolling_shutter=rolling_shutter,
viewmats_rs=viewmats_rs,
return_sample_counts=False,
Expand Down Expand Up @@ -825,6 +871,7 @@ def rasterize_to_pixels_eval3d_extra(
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
# rolling shutter
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]
Expand Down Expand Up @@ -992,6 +1039,7 @@ def rasterize_to_pixels_eval3d_extra(
tangential_coeffs.contiguous() if tangential_coeffs is not None else None,
thin_prism_coeffs.contiguous() if thin_prism_coeffs is not None else None,
ftheta_coeffs,
external_distortion_coeffs,
# rolling shutter
rolling_shutter,
viewmats_rs.contiguous() if viewmats_rs is not None else None,
Expand Down Expand Up @@ -1353,6 +1401,7 @@ def fully_fused_projection_with_ut(
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
# rolling shutter
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]
Expand Down Expand Up @@ -1425,6 +1474,11 @@ def fully_fused_projection_with_ut(
if ftheta_coeffs is not None
else FThetaCameraDistortionParameters.to_cpp_default()
),
(
external_distortion_coeffs.to_cpp()
if external_distortion_coeffs is not None
else None
),
)
if not calc_compensations:
compensations = None
Expand Down Expand Up @@ -1589,6 +1643,7 @@ def forward(
tangential_coeffs: Optional[Tensor] = None, # [..., C, 2]
thin_prism_coeffs: Optional[Tensor] = None, # [..., C, 4]
ftheta_coeffs: Optional[FThetaCameraDistortionParameters] = None,
external_distortion_coeffs: Optional[BivariateWindshieldModelParameters] = None,
# rolling shutter
rolling_shutter: RollingShutterType = RollingShutterType.GLOBAL,
viewmats_rs: Optional[Tensor] = None, # [..., C, 4, 4]
Expand All @@ -1607,6 +1662,11 @@ def forward(
else FThetaCameraDistortionParameters.to_cpp_default()
)

external_distortion_coeffs = (
external_distortion_coeffs.to_cpp()
if external_distortion_coeffs is not None
else None
)
# Extract batch_dims for sample_counts allocation
batch_dims = means.shape[:-2]
C = viewmats.size(-3)
Expand Down Expand Up @@ -1654,6 +1714,7 @@ def forward(
tangential_coeffs,
thin_prism_coeffs,
ftheta_coeffs,
external_distortion_coeffs,
isect_offsets,
flatten_ids,
use_hit_distance,
Expand Down Expand Up @@ -1688,6 +1749,7 @@ def forward(
ctx.camera_model_type = camera_model_type
ctx.tile_size = tile_size
ctx.ftheta_coeffs = ftheta_coeffs
ctx.external_distortion_coeffs = external_distortion_coeffs
ctx.use_hit_distance = use_hit_distance

return render_colors, render_alphas, last_ids, sample_counts, render_normals
Expand Down Expand Up @@ -1730,6 +1792,7 @@ def backward(
camera_model_type = ctx.camera_model_type
tile_size = ctx.tile_size
ftheta_coeffs = ctx.ftheta_coeffs
external_distortion_coeffs = ctx.external_distortion_coeffs
use_hit_distance = ctx.use_hit_distance

(
Expand Down Expand Up @@ -1761,6 +1824,7 @@ def backward(
tangential_coeffs,
thin_prism_coeffs,
ftheta_coeffs,
external_distortion_coeffs,
isect_offsets,
flatten_ids,
use_hit_distance,
Expand Down Expand Up @@ -1804,6 +1868,7 @@ def backward(
None, # tangential_coeffs
None, # thin_prism_coeffs
None, # ftheta_coeffs
None, # external_distortion_coeffs
None, # rolling_shutter
None, # viewmats_rs
None, # return_sample_counts (flag, no gradient)
Expand Down
11 changes: 10 additions & 1 deletion gsplat/cuda/csrc/Projection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,8 @@ projection_ut_3dgs_fused(
const at::optional<at::Tensor> radial_coeffs, // [..., C, 6] or [..., C, 4] optional
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
const FThetaCameraDistortionParameters ftheta_coeffs // shared parameters for all cameras
const FThetaCameraDistortionParameters ftheta_coeffs, // shared parameters for all cameras
const std::optional<extdist::BivariateWindshieldModelParameters> external_distortion_params // external distortion parameters
) {
DEVICE_GUARD(means);
CHECK_INPUT(means);
Expand All @@ -985,6 +986,13 @@ projection_ut_3dgs_fused(
CHECK_INPUT(thin_prism_coeffs.value());
}

if (external_distortion_params.has_value()) {
CHECK_INPUT(external_distortion_params->horizontal_poly);
CHECK_INPUT(external_distortion_params->vertical_poly);
CHECK_INPUT(external_distortion_params->horizontal_poly_inverse); // Could be omitted for projection
CHECK_INPUT(external_distortion_params->vertical_poly_inverse); // Could be omitted for projection
}

at::DimVector batch_dims(means.sizes().slice(0, means.dim() - 2));
uint32_t N = means.size(-2); // number of gaussians
uint32_t C = Ks.size(-3); // number of cameras
Expand Down Expand Up @@ -1038,6 +1046,7 @@ projection_ut_3dgs_fused(
tangential_coeffs,
thin_prism_coeffs,
ftheta_coeffs,
external_distortion_params,
// outputs
radii,
means2d,
Expand Down
3 changes: 3 additions & 0 deletions gsplat/cuda/csrc/Projection.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
#pragma once

#include <cstdint>

#include "Cameras.h"
#include "ExternalDistortion.h"

namespace at {
class Tensor;
Expand Down Expand Up @@ -291,6 +293,7 @@ void launch_projection_ut_3dgs_fused_kernel(
const at::optional<at::Tensor> tangential_coeffs, // [C, 2] optional
const at::optional<at::Tensor> thin_prism_coeffs, // [C, 4] optional
const FThetaCameraDistortionParameters ftheta_coeffs, // shared parameters for all cameras
const std::optional<extdist::BivariateWindshieldModelParameters> external_distortion_params, // external distortion parameters
// outputs
at::Tensor radii, // [C, N, 2]
at::Tensor means2d, // [C, N, 2]
Expand Down
18 changes: 18 additions & 0 deletions gsplat/cuda/csrc/ProjectionUT3DGSFused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
#include <ATen/cuda/Atomic.cuh>
#include <c10/cuda/CUDAStream.h>
#include <cooperative_groups.h>
#include <cuda/std/optional>

#include "Common.h"
#include "ExternalDistortion.cuh"
#include "Projection.h"
#include "Utils.cuh"
#include "Cameras.cuh"
Expand Down Expand Up @@ -62,6 +64,7 @@ __global__ void projection_ut_3dgs_fused_kernel(
const scalar_t *__restrict__ tangential_coeffs, // [B, C, 2] optional
const scalar_t *__restrict__ thin_prism_coeffs, // [B, C, 4] optional
const FThetaCameraDistortionParameters ftheta_coeffs, // shared parameters for all cameras
const cuda::std::optional<extdist::BivariateWindshieldModelDeviceParams> external_distortion_device_params, // external distortion parameters
// outputs
int32_t *__restrict__ radii, // [B, C, N, 2]
scalar_t *__restrict__ means2d, // [B, C, N, 2]
Expand Down Expand Up @@ -118,6 +121,8 @@ __global__ void projection_ut_3dgs_fused_kernel(
cm_params.shutter_type = rs_type;
cm_params.principal_point = { principal_point.x, principal_point.y };
cm_params.focal_length = { focal_length.x, focal_length.y };
cm_params.external_distortion_params = external_distortion_device_params.has_value() ?
&external_distortion_device_params.value() : nullptr;
PerfectPinholeCameraModel camera_model(cm_params);
image_gaussian_return =
world_gaussian_to_image_gaussian_unscented_transform_shutter_pose(
Expand All @@ -137,6 +142,8 @@ __global__ void projection_ut_3dgs_fused_kernel(
if (thin_prism_coeffs != nullptr) {
cm_params.thin_prism_coeffs = make_array<float, 4>(thin_prism_coeffs + bid * C * 4 + cid * 4);
}
cm_params.external_distortion_params = external_distortion_device_params.has_value() ?
&external_distortion_device_params.value() : nullptr;
OpenCVPinholeCameraModel camera_model(cm_params);
image_gaussian_return =
world_gaussian_to_image_gaussian_unscented_transform_shutter_pose(
Expand All @@ -151,6 +158,8 @@ __global__ void projection_ut_3dgs_fused_kernel(
if (radial_coeffs != nullptr) {
cm_params.radial_coeffs = make_array<float, 4>(radial_coeffs + bid * C * 4 + cid * 4);
}
cm_params.external_distortion_params = external_distortion_device_params.has_value() ?
&external_distortion_device_params.value() : nullptr;
OpenCVFisheyeCameraModel camera_model(cm_params);
image_gaussian_return =
world_gaussian_to_image_gaussian_unscented_transform_shutter_pose(
Expand All @@ -162,6 +171,8 @@ __global__ void projection_ut_3dgs_fused_kernel(
cm_params.shutter_type = rs_type;
cm_params.principal_point = { principal_point.x, principal_point.y };
cm_params.dist = ftheta_coeffs;
cm_params.external_distortion_params = external_distortion_device_params.has_value() ?
&external_distortion_device_params.value() : nullptr;
FThetaCameraModel camera_model(cm_params);
image_gaussian_return =
world_gaussian_to_image_gaussian_unscented_transform_shutter_pose(
Expand Down Expand Up @@ -268,6 +279,7 @@ void launch_projection_ut_3dgs_fused_kernel(
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
const FThetaCameraDistortionParameters ftheta_coeffs, // shared parameters for all cameras
const std::optional<extdist::BivariateWindshieldModelParameters> external_distortion_params, // external distortion parameters
// outputs
at::Tensor radii, // [..., C, N, 2]
at::Tensor means2d, // [..., C, N, 2]
Expand All @@ -289,6 +301,11 @@ void launch_projection_ut_3dgs_fused_kernel(
return;
}

cuda::std::optional<extdist::BivariateWindshieldModelDeviceParams> external_distortion_device_params = cuda::std::nullopt;
if (external_distortion_params.has_value()) {
external_distortion_device_params = extdist::BivariateWindshieldModelDeviceParams(external_distortion_params.value());
}

projection_ut_3dgs_fused_kernel<float>
<<<grid,
threads,
Expand Down Expand Up @@ -325,6 +342,7 @@ void launch_projection_ut_3dgs_fused_kernel(
? thin_prism_coeffs.value().data_ptr<float>()
: nullptr,
ftheta_coeffs,
external_distortion_device_params,
radii.data_ptr<int32_t>(),
means2d.data_ptr<float>(),
depths.data_ptr<float>(),
Expand Down
19 changes: 19 additions & 0 deletions gsplat/cuda/csrc/Rasterization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_from_world_3d
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
const FThetaCameraDistortionParameters ftheta_coeffs, // shared parameters for all cameras
const std::optional<extdist::BivariateWindshieldModelParameters> external_distortion_params,
// intersections
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
Expand All @@ -698,6 +699,14 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_from_world_3d
if (masks.has_value()) {
CHECK_INPUT(masks.value());
}

if (external_distortion_params.has_value()) {
CHECK_INPUT(external_distortion_params->horizontal_poly);
CHECK_INPUT(external_distortion_params->vertical_poly);
CHECK_INPUT(external_distortion_params->horizontal_poly_inverse);
CHECK_INPUT(external_distortion_params->vertical_poly_inverse);
}

if (sample_counts.has_value()) {
CHECK_INPUT(sample_counts.value());
}
Expand Down Expand Up @@ -745,6 +754,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> rasterize_to_pixels_from_world_3d
tangential_coeffs, \
thin_prism_coeffs, \
ftheta_coeffs, \
external_distortion_params, \
tile_offsets, \
flatten_ids, \
use_hit_distance, \
Expand Down Expand Up @@ -797,6 +807,7 @@ rasterize_to_pixels_from_world_3dgs_bwd(
const at::optional<at::Tensor> tangential_coeffs, // [..., C, 2] optional
const at::optional<at::Tensor> thin_prism_coeffs, // [..., C, 4] optional
const FThetaCameraDistortionParameters ftheta_coeffs, // shared parameters for all cameras
const std::optional<extdist::BivariateWindshieldModelParameters> external_distortion_params,
// intersections
const at::Tensor tile_offsets, // [..., C, tile_height, tile_width]
const at::Tensor flatten_ids, // [n_isects]
Expand Down Expand Up @@ -834,6 +845,13 @@ rasterize_to_pixels_from_world_3dgs_bwd(
CHECK_INPUT(v_render_normals.value());
}

if (external_distortion_params.has_value()) {
CHECK_INPUT(external_distortion_params->horizontal_poly);
CHECK_INPUT(external_distortion_params->vertical_poly);
CHECK_INPUT(external_distortion_params->horizontal_poly_inverse);
CHECK_INPUT(external_distortion_params->vertical_poly_inverse);
}

uint32_t channels = colors.size(-1);

at::Tensor v_means = at::zeros_like(means);
Expand Down Expand Up @@ -867,6 +885,7 @@ rasterize_to_pixels_from_world_3dgs_bwd(
tangential_coeffs, \
thin_prism_coeffs, \
ftheta_coeffs, \
external_distortion_params, \
tile_offsets, \
flatten_ids, \
use_hit_distance, \
Expand Down
Loading
Loading