
feat: Add background gaussian splatting model #4

Merged: 9 commits, Oct 16, 2024

feat: Add background gaussian
hugoycj committed Oct 13, 2024
commit 465ba80c1911bc746013a4d5323cad944ce8d137
72 changes: 24 additions & 48 deletions gaussian_renderer/__init__.py
@@ -16,15 +16,29 @@
from utils.sh_utils import eval_sh
from utils.point_utils import depth_to_normal

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, record_transmittance=False):
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0,
override_color = None, record_transmittance=False, bg_gaussians=None):
"""
Render the scene.

Background tensor (bg_color) must be on GPU!
"""

if bg_gaussians is None:
means3D = pc.get_xyz
opacity = pc.get_opacity
scales = pc.get_scaling
rotations = pc.get_rotation
shs = pc.get_features
else:
means3D = torch.cat([pc.get_xyz, bg_gaussians.get_xyz])
opacity = torch.cat([pc.get_opacity, bg_gaussians.get_opacity])
scales = torch.cat([pc.get_scaling, bg_gaussians.get_scaling])
rotations = torch.cat([pc.get_rotation, bg_gaussians.get_rotation])
shs = torch.cat([pc.get_features, bg_gaussians.get_features])
num_fg_points = pc.get_xyz.shape[0]

# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros((pc.get_xyz.shape[0], 4), dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
screenspace_points = torch.zeros((means3D.shape[0], 4), dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
try:
screenspace_points.retain_grad()
except:
@@ -53,64 +67,26 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,

rasterizer = GaussianRasterizer(raster_settings=raster_settings)

means3D = pc.get_xyz
means2D = screenspace_points
opacity = pc.get_opacity

# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
# scaling / rotation by the rasterizer.
scales = None
rotations = None
cov3D_precomp = None
if pipe.compute_cov3D_python:
# currently don't support normal consistency loss if use precomputed covariance
splat2world = pc.get_covariance(scaling_modifier)
W, H = viewpoint_camera.image_width, viewpoint_camera.image_height
near, far = viewpoint_camera.znear, viewpoint_camera.zfar
ndc2pix = torch.tensor([
[W / 2, 0, 0, (W-1) / 2],
[0, H / 2, 0, (H-1) / 2],
[0, 0, far-near, near],
[0, 0, 0, 1]]).float().cuda().T
world2pix = viewpoint_camera.full_proj_transform @ ndc2pix
cov3D_precomp = (splat2world[:, [0,1,3]] @ world2pix[:,[0,1,3]]).permute(0,2,1).reshape(-1, 9) # column major
else:
scales = pc.get_scaling
rotations = pc.get_rotation

# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
pipe.convert_SHs_python = False
shs = None
colors_precomp = None
if override_color is None:
if pipe.convert_SHs_python:
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
else:
shs = pc.get_features
else:
colors_precomp = override_color


output = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
colors_precomp = None,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp
)

cov3D_precomp = None)

if record_transmittance:
rendered_image, radii, allmap, transmittance_avg, num_covered_pixels = output
transmittance_avg = transmittance_avg[:num_fg_points]
num_covered_pixels = num_covered_pixels[:num_fg_points]
else:
rendered_image, radii, allmap = output
transmittance_avg = num_covered_pixels = None
radii = radii[:num_fg_points]
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
rets = {"render": rendered_image,
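Aside (not part of the diff): a minimal, self-contained sketch of the merge-then-slice pattern the new bg_gaussians path follows. Foreground and background attributes are concatenated along the point dimension before rasterization, and per-Gaussian outputs (radii, transmittance) are sliced back to the first num_fg_points entries so that only pc's points feed the densification statistics. The merge_fg_bg helper and all tensor shapes are illustrative assumptions, not code from the repository.

```python
import torch

def merge_fg_bg(fg: dict, bg: dict):
    """Concatenate per-Gaussian attributes of two models along dim 0 and
    remember how many foreground points there are, so per-Gaussian outputs
    can be sliced back afterwards."""
    num_fg = fg["means3D"].shape[0]
    merged = {k: torch.cat([fg[k], bg[k]]) for k in fg}
    return merged, num_fg

# Toy usage with random stand-in tensors (the real code uses pc.get_xyz etc.).
fg = {"means3D": torch.rand(100, 3), "opacity": torch.rand(100, 1),
      "scales": torch.rand(100, 2), "rotations": torch.rand(100, 4)}
bg = {"means3D": torch.rand(30, 3), "opacity": torch.rand(30, 1),
      "scales": torch.rand(30, 2), "rotations": torch.rand(30, 4)}
merged, num_fg_points = merge_fg_bg(fg, bg)

radii = torch.randint(0, 10, (merged["means3D"].shape[0],))  # stand-in for a rasterizer output
fg_radii = radii[:num_fg_points]  # background Gaussians are excluded from densification stats
assert fg_radii.shape[0] == 100
```

This also appears to be why add_densification_stats below now indexes the gradient with [:len(update_filter)]: the screen-space gradient tensor covers both models, while update_filter and the accumulators only cover the foreground points.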
7 changes: 5 additions & 2 deletions scene/__init__.py
@@ -14,7 +14,7 @@
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from scene.gaussian_model import GaussianModel, BgGaussianModel
from scene.appearance_model import AppearanceModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
@@ -23,13 +23,14 @@ class Scene:

gaussians : GaussianModel

def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
def __init__(self, args : ModelParams, gaussians : GaussianModel, bg_gaussians: BgGaussianModel = None, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
"""b
:param path: Path to colmap scene main folder.
"""
self.model_path = args.model_path
self.loaded_iter = None
self.gaussians = gaussians
self.bg_gaussians = bg_gaussians

if load_iteration:
if load_iteration == -1:
@@ -82,10 +83,12 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration
"point_cloud.ply"))
else:
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
self.bg_gaussians.load_ply('background_gs.ply')

def save(self, iteration):
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
self.bg_gaussians.save_ply(os.path.join(point_cloud_path, "bg_point_cloud.ply"))

def getTrainCameras(self, scale=1.0):
return self.train_cameras[scale]
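Aside (not the PR's code): the Scene constructor defaults bg_gaussians to None, while the new load_ply('background_gs.ply') and save_ply calls are made unconditionally. A hypothetical None-guarded version of the save path, for illustration only:

```python
import os

def save_point_clouds(model_path, iteration, gaussians, bg_gaussians=None):
    """Sketch of Scene.save with a guard for the optional background model."""
    point_cloud_path = os.path.join(model_path, "point_cloud/iteration_{}".format(iteration))
    gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
    if bg_gaussians is not None:  # hypothetical guard; the diff calls save_ply unconditionally
        bg_gaussians.save_ply(os.path.join(point_cloud_path, "bg_point_cloud.ply"))
```

The same guard would apply to the load_ply call added in the constructor.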
54 changes: 51 additions & 3 deletions scene/gaussian_model.py
@@ -274,6 +274,7 @@ def load_ply(self, path):
scales = np.zeros((xyz.shape[0], len(scale_names)))
for idx, attr_name in enumerate(scale_names):
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
scales = scales[:, :2]

rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
@@ -478,10 +479,10 @@ def split_big_points(self, max_screen_size):

def add_densification_stats(self, viewspace_point_tensor, update_filter, pixels):
if pixels is not None:
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter], dim=-1, keepdim=True) * pixels[update_filter].unsqueeze(-1)
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[:len(update_filter)][update_filter], dim=-1, keepdim=True) * pixels[update_filter].unsqueeze(-1)
self.denom[update_filter] += pixels[update_filter].unsqueeze(-1)
else:
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter], dim=-1, keepdim=True)
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[:len(update_filter)][update_filter], dim=-1, keepdim=True)
self.denom[update_filter] += 1


@@ -542,4 +543,51 @@ def densify_from_depth_propagation(self, viewpoint_cam, propagated_depth, propag
new_opacity = nn.Parameter(opacities.requires_grad_(True))

#update gaussians
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)

class BgGaussianModel(GaussianModel):
def __init__(self, sh_degree: int):
self.active_sh_degree = 3
self.max_sh_degree = sh_degree
self._features_dc = torch.empty(0)
self._features_rest = torch.empty(0)
self._scaling = torch.empty(0)
self._rotation = torch.empty(0)
self._opacity = torch.empty(0)
self.optimizer = None
self.setup_functions()

def capture(self):
return (
self.active_sh_degree,
self._features_dc,
self._features_rest,
self._scaling,
self._rotation,
self._opacity,
self.optimizer.state_dict(),
)

def restore(self, model_args):
(
self.active_sh_degree,
self._features_dc,
self._features_rest,
self._scaling,
self._rotation,
self._opacity,
opt_dict,
) = model_args
self.setup_optimizer()
self.optimizer.load_state_dict(opt_dict)

def training_setup(self, training_args):
l = [
{'params': [self._features_dc], 'lr': 0.01, "name": "f_dc"},
{'params': [self._features_rest], 'lr': 0.0005, "name": "f_rest"},
{'params': [self._opacity], 'lr': 0.05, "name": "opacity"},
{'params': [self._scaling], 'lr': 0.005, "name": "scaling"},
{'params': [self._rotation], 'lr': 0.001, "name": "rotation"}
]

self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
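For reference, a runnable sketch of the per-parameter-group Adam setup that BgGaussianModel.training_setup uses above. Only the learning rates and group names come from the diff; the tensor shapes are placeholder assumptions. Note also that restore calls setup_optimizer, a method not shown in this diff, and that the per-group learning rates are fixed rather than read from training_args.

```python
import torch
import torch.nn as nn

# Placeholder parameters standing in for the background model's buffers
# (shapes are illustrative, not taken from the repository).
num_points, sh_degree = 1000, 3
f_dc = nn.Parameter(torch.zeros(num_points, 1, 3))
f_rest = nn.Parameter(torch.zeros(num_points, (sh_degree + 1) ** 2 - 1, 3))
opacity = nn.Parameter(torch.zeros(num_points, 1))
scaling = nn.Parameter(torch.zeros(num_points, 2))
rotation = nn.Parameter(torch.zeros(num_points, 4))

# Same pattern as training_setup above: one parameter group per attribute,
# each with its own learning rate; the top-level lr=0.0 is only a default
# that every group overrides.
param_groups = [
    {"params": [f_dc], "lr": 0.01, "name": "f_dc"},
    {"params": [f_rest], "lr": 0.0005, "name": "f_rest"},
    {"params": [opacity], "lr": 0.05, "name": "opacity"},
    {"params": [scaling], "lr": 0.005, "name": "scaling"},
    {"params": [rotation], "lr": 0.001, "name": "rotation"},
]
optimizer = torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)

# One toy step; parameter groups whose tensors received no gradient are skipped.
loss = (f_dc ** 2).mean() + (scaling ** 2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```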
22 changes: 14 additions & 8 deletions train.py
@@ -16,7 +16,7 @@
from gaussian_renderer import render, network_gui
import sys
import torch.nn.functional as F
from scene import Scene, GaussianModel, AppearanceModel
from scene import Scene, GaussianModel, BgGaussianModel, AppearanceModel
from utils.general_utils import safe_state
from utils.patchmatch import process_propagation
import uuid
@@ -119,8 +119,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians)
bg_gaussians = BgGaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, bg_gaussians)
gaussians.training_setup(opt)
bg_gaussians.training_setup(opt)
if checkpoint:
(model_params, first_iter) = torch.load(checkpoint)
gaussians.restore(model_params, opt)
@@ -161,10 +163,11 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
viewpoint_idx = randint(0, len(all_cameras)-1)
viewpoint_cam = all_cameras[viewpoint_idx]
# Set intervals for patch match
intervals = [-2, -1, 1, 2]
src_idxs = [viewpoint_idx+itv for itv in intervals if ((itv + viewpoint_idx > 0) and (itv + viewpoint_idx < len(viewpoint_stack)))]
process_propagation(viewpoint_stack, viewpoint_cam, gaussians, pipe, background, iteration, opt, src_idxs)
render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter))
# intervals = [-2, -1, 1, 2]
# src_idxs = [viewpoint_idx+itv for itv in intervals if ((itv + viewpoint_idx > 0) and (itv + viewpoint_idx < len(viewpoint_stack)))]
# process_propagation(viewpoint_stack, viewpoint_cam, gaussians, pipe, background, iteration, opt, src_idxs)
render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter), bg_gaussians=bg_gaussians)
# render_pkg = render(viewpoint_cam, gaussians, pipe, background, record_transmittance=(iteration < opt.densify_until_iter))
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]

gt_image = viewpoint_cam.original_image.cuda()
@@ -250,7 +253,7 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
tb_writer.add_scalar('train_loss_patches/dist_loss', ema_depth_for_log, iteration)
tb_writer.add_scalar('train_loss_patches/normal_loss', ema_normal_for_log, iteration)

training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
# training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
scene.save(iteration)
@@ -289,6 +292,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
if appearances is not None:
appearances.optimizer.step()
appearances.optimizer.zero_grad(set_to_none = True)
if bg_gaussians is not None:
bg_gaussians.optimizer.step()
bg_gaussians.optimizer.zero_grad(set_to_none = True)

if (iteration in checkpoint_iterations):
print("\n[ITER {}] Saving Checkpoint".format(iteration))
@@ -410,7 +416,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
parser.add_argument('--port', type=int, default=6009)
parser.add_argument('--detect_anomaly', action='store_true', default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[1, 7_000, 20_000, 30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[500, 7_000, 20_000, 30_000])
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)
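Finally, a toy, self-contained illustration (not the project's code) of the optional-submodel pattern the training loop now follows: the background model's optimizer is stepped and zeroed only when the model exists, mirroring the appearance-model handling just above it in the diff. Note that the checkpoint branch shown earlier restores only gaussians, so resuming from a checkpoint would appear to leave bg_gaussians at its initial state.

```python
import torch
import torch.nn as nn

class ToySubModel:
    """Stand-in for a model that owns its own optimizer (like BgGaussianModel)."""
    def __init__(self):
        self.param = nn.Parameter(torch.zeros(4))
        self.optimizer = torch.optim.Adam([self.param], lr=1e-2)

foreground = ToySubModel()
background = ToySubModel()  # may also be None when no background model is used

for iteration in range(3):
    loss = (foreground.param - 1.0).pow(2).sum()
    if background is not None:
        loss = loss + (background.param + 1.0).pow(2).sum()
    loss.backward()

    foreground.optimizer.step()
    foreground.optimizer.zero_grad(set_to_none=True)
    if background is not None:  # same guard train.py uses for the appearance model
        background.optimizer.step()
        background.optimizer.zero_grad(set_to_none=True)
```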