Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename to splatfacto #2795

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/nerfology/methods/splat.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ Because gaussian splatting trains on *full images* instead of bundles of rays, t


### Running the Method
To run gaussian splatting, run `ns-train gaussian-splatting --data <data>`. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported.
To run gaussian splatting, run `ns-train splatfacto --data <data>`. Just like NeRF methods, the splat can be interactively viewed in the web-viewer, loaded from a checkpoint, rendered, and exported.

#### Quality and Regularization
The default settings provided maintain a balance between speed, quality, and splat file size, but if you care more about quality than training speed or size, you can decrease the alpha cull threshold
(threshold to delete translucent gaussians) and disable culling after 15k steps like so: `ns-train gaussian-splatting --pipeline.model.cull_scale_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data <data>`
(threshold to delete translucent gaussians) and disable culling after 15k steps like so: `ns-train splatfacto --pipeline.model.cull_scale_thresh=0.005 --pipeline.model.continue_cull_post_densification=False --data <data>`

A common artifact in splatting is long, spikey gaussians. [PhysGaussian](https://xpandora.github.io/PhysGaussian/) proposes an example of a scale-regularizer that encourages gaussians to be more evenly shaped. To enable this, set the `use_scale_regularization` flag to `True`.

Expand Down
10 changes: 5 additions & 5 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.fields.sdf_field import SDFFieldConfig
from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig
from nerfstudio.models.gaussian_splatting import GaussianSplattingModelConfig
from nerfstudio.models.generfacto import GenerfactoModelConfig
from nerfstudio.models.instant_ngp import InstantNGPModelConfig
from nerfstudio.models.mipnerf import MipNerfModel
from nerfstudio.models.nerfacto import NerfactoModelConfig
from nerfstudio.models.neus import NeuSModelConfig
from nerfstudio.models.neus_facto import NeuSFactoModelConfig
from nerfstudio.models.semantic_nerfw import SemanticNerfWModelConfig
from nerfstudio.models.splatfacto import SplatfactoModelConfig
from nerfstudio.models.tensorf import TensoRFModelConfig
from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig
Expand All @@ -80,7 +80,7 @@
"generfacto": "Generative Text to NeRF model",
"neus": "Implementation of NeuS. (slow)",
"neus-facto": "Implementation of NeuS-Facto. (slow)",
"gaussian-splatting": "Gaussian Splatting model",
"splatfacto": "Gaussian Splatting model",
}

method_configs["nerfacto"] = TrainerConfig(
Expand Down Expand Up @@ -588,8 +588,8 @@
vis="viewer",
)

method_configs["gaussian-splatting"] = TrainerConfig(
method_name="gaussian-splatting",
method_configs["splatfacto"] = TrainerConfig(
method_name="splatfacto",
steps_per_eval_image=100,
steps_per_eval_batch=0,
steps_per_save=2000,
Expand All @@ -601,7 +601,7 @@
datamanager=FullImageDatamanagerConfig(
dataparser=NerfstudioDataParserConfig(load_3D_points=True),
),
model=GaussianSplattingModelConfig(),
model=SplatfactoModelConfig(),
),
optimizers={
"xyz": {
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/data/dataparsers/nerfstudio_dataparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _generate_dataparser_outputs(self, split="train"):
else:
if not self.prompted_user:
CONSOLE.print(
"[bold yellow]Warning: load_3D_points set to true but no point cloud found. gaussian-splatting models will use random point cloud initialization."
"[bold yellow]Warning: load_3D_points set to true but no point cloud found. splatfacto models will use random point cloud initialization."
)
ply_file_path = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def projection_matrix(znear, zfar, fovx, fovy, device: Union[str, torch.device]


@dataclass
class GaussianSplattingModelConfig(ModelConfig):
class SplatfactoModelConfig(ModelConfig):
"""Gaussian Splatting Model Config"""

_target: Type = field(default_factory=lambda: GaussianSplattingModel)
_target: Type = field(default_factory=lambda: SplatfactoModel)
warmup_length: int = 500
"""period of steps where refinement is turned off"""
refine_every: int = 100
Expand Down Expand Up @@ -149,14 +149,14 @@ class GaussianSplattingModelConfig(ModelConfig):
"""


class GaussianSplattingModel(Model):
class SplatfactoModel(Model):
"""Gaussian Splatting model

Args:
config: Gaussian Splatting configuration to instantiate model
"""

config: GaussianSplattingModelConfig
config: SplatfactoModelConfig

def __init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from nerfstudio.exporter.exporter_utils import collect_camera_poses, generate_point_cloud, get_mesh_from_filename
from nerfstudio.exporter.marching_cubes import generate_mesh_with_multires_marching_cubes
from nerfstudio.fields.sdf_field import SDFField # noqa
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.models.splatfacto import SplatfactoModel
from nerfstudio.pipelines.base_pipeline import Pipeline, VanillaPipeline
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.utils.rich_utils import CONSOLE
Expand Down Expand Up @@ -488,9 +488,9 @@ def main(self) -> None:

_, pipeline, _, _ = eval_setup(self.load_config)

assert isinstance(pipeline.model, GaussianSplattingModel)
assert isinstance(pipeline.model, SplatfactoModel)

model: GaussianSplattingModel = pipeline.model
model: SplatfactoModel = pipeline.model

filename = self.output_dir / "splat.ply"

Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/viewer/export_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.models.base_model import Model
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.models.splatfacto import SplatfactoModel
from nerfstudio.viewer.control_panel import ControlPanel


Expand All @@ -32,7 +32,7 @@ def populate_export_tab(
config_path: Path,
viewer_model: Model,
) -> None:
viewing_gsplat = isinstance(viewer_model, GaussianSplattingModel)
viewing_gsplat = isinstance(viewer_model, SplatfactoModel)
if not viewing_gsplat:
crop_output = server.add_gui_checkbox("Use Crop", False)

Expand Down
6 changes: 3 additions & 3 deletions nerfstudio/viewer/render_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.model_components.renderers import background_color_override_context
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.models.splatfacto import SplatfactoModel
from nerfstudio.utils import colormaps, writer
from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName, TimeWriter
from nerfstudio.viewer.utils import CameraState, get_camera
Expand Down Expand Up @@ -136,7 +136,7 @@ def _render_img(self, camera_state: CameraState):

with TimeWriter(None, None, write=False) as vis_t:
with self.viewer.train_lock if self.viewer.train_lock is not None else contextlib.nullcontext():
if isinstance(self.viewer.get_model(), GaussianSplattingModel):
if isinstance(self.viewer.get_model(), SplatfactoModel):
color = self.viewer.control_panel.background_color
background_color = torch.tensor(
[color[0] / 255.0, color[1] / 255.0, color[2] / 255.0],
Expand Down Expand Up @@ -168,7 +168,7 @@ def _render_img(self, camera_state: CameraState):
self.viewer.get_model().train()
num_rays = (camera.height * camera.width).item()
if self.viewer.control_panel.layer_depth:
if isinstance(self.viewer.get_model(), GaussianSplattingModel):
if isinstance(self.viewer.get_model(), SplatfactoModel):
# Gaussians render much faster than we can send depth images, so we do some downsampling.
assert len(outputs["depth"].shape) == 3
assert outputs["depth"].shape[-1] == 1
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from nerfstudio.configs import base_config as cfg
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.models.base_model import Model
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.models.splatfacto import SplatfactoModel
from nerfstudio.pipelines.base_pipeline import Pipeline
from nerfstudio.utils.decorators import check_main_thread, decorate_all
from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName
Expand Down Expand Up @@ -250,7 +250,7 @@ def nested_folder_install(folder_labels: List[str], prev_labels: List[str], elem

# Diagnostics for Gaussian Splatting: where the points are at the start of training.
# This is hidden by default, it can be shown from the Viser UI's scene tree table.
if isinstance(pipeline.model, GaussianSplattingModel):
if isinstance(pipeline.model, SplatfactoModel):
self.viser_server.add_point_cloud(
"/gaussian_splatting_initial_points",
points=pipeline.model.means.numpy(force=True) * VISER_NERFSTUDIO_SCALE_RATIO,
Expand Down
4 changes: 2 additions & 2 deletions nerfstudio/viewer_legacy/server/render_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from nerfstudio.cameras.cameras import Cameras
from nerfstudio.model_components.renderers import background_color_override_context
from nerfstudio.models.gaussian_splatting import GaussianSplattingModel
from nerfstudio.models.splatfacto import SplatfactoModel
from nerfstudio.utils import colormaps, writer
from nerfstudio.utils.writer import GLOBAL_BUFFER, EventName, TimeWriter
from nerfstudio.viewer_legacy.server import viewer_utils
Expand Down Expand Up @@ -130,7 +130,7 @@ def _render_img(self, cam_msg: CameraMessage):
with self.viewer.train_lock if self.viewer.train_lock is not None else contextlib.nullcontext():
# TODO jake-austin: Make this check whether the model inherits from a camera based model or a ray based model
# TODO Zhuoyang: First made some dummy judgements, need to be fixed later
isGaussianSplattingModel = isinstance(self.viewer.get_model(), GaussianSplattingModel)
isGaussianSplattingModel = isinstance(self.viewer.get_model(), SplatfactoModel)
if isGaussianSplattingModel:
# TODO fix me before ship
camera_ray_bundle = camera.generate_rays(camera_indices=0, aabb_box=self.viewer.get_model().render_aabb)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"neus",
"generfacto",
"neus-facto",
"gaussian-splatting",
"splatfacto",
]


Expand Down
Loading