Skip to content

Commit 9751f1f

Browse files
davnov134facebook-github-bot
authored andcommitted
Main training script
Summary: Implements the training script of NeRF. Reviewed By: nikhilaravi Differential Revision: D25684439 fbshipit-source-id: 8b19b6dc282eb6bf6e46ec4476bb0f13a84c90dd
1 parent 5b74911 commit 9751f1f

File tree

6 files changed

+466
-1
lines changed

6 files changed

+466
-1
lines changed

projects/nerf/__init__.py

Whitespace-only changes.

projects/nerf/configs/fern.yaml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
seed: 3
2+
resume: True
3+
stats_print_interval: 10
4+
validation_epoch_interval: 150
5+
checkpoint_epoch_interval: 150
6+
checkpoint_path: 'checkpoints/fern_pt3d.pth'
7+
data:
8+
dataset_name: 'fern'
9+
image_size: [378, 504] # [height, width]
10+
precache_rays: True
11+
test:
12+
mode: 'evaluation'
13+
trajectory_type: 'figure_eight'
14+
up: [0.0, 1.0, 0.0]
15+
scene_center: [0.0, 0.0, -2.0]
16+
n_frames: 100
17+
fps: 20
18+
optimizer:
19+
max_epochs: 37500
20+
lr: 0.0005
21+
lr_scheduler_step_size: 12500
22+
lr_scheduler_gamma: 0.1
23+
visualization:
24+
history_size: 10
25+
visdom: True
26+
visdom_server: 'localhost'
27+
visdom_port: 8097
28+
visdom_env: 'nerf_pytorch3d'
29+
raysampler:
30+
n_pts_per_ray: 64
31+
n_pts_per_ray_fine: 64
32+
n_rays_per_image: 1024
33+
min_depth: 1.2
34+
max_depth: 6.28
35+
stratified: True
36+
stratified_test: False
37+
chunk_size_test: 6000
38+
implicit_function:
39+
n_harmonic_functions_xyz: 10
40+
n_harmonic_functions_dir: 4
41+
n_hidden_neurons_xyz: 256
42+
n_hidden_neurons_dir: 128
43+
density_noise_std: 0.0
44+
n_layers_xyz: 8

projects/nerf/configs/lego.yaml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
seed: 3
2+
resume: True
3+
stats_print_interval: 10
4+
validation_epoch_interval: 30
5+
checkpoint_epoch_interval: 30
6+
checkpoint_path: 'checkpoints/lego_pt3d.pth'
7+
data:
8+
dataset_name: 'lego'
9+
image_size: [800, 800] # [height, width]
10+
precache_rays: True
11+
test:
12+
mode: 'evaluation'
13+
trajectory_type: 'circular'
14+
up: [0.0, 0.0, 1.0]
15+
scene_center: [0.0, 0.0, 0.0]
16+
n_frames: 100
17+
fps: 20
18+
optimizer:
19+
max_epochs: 20000
20+
lr: 0.0005
21+
lr_scheduler_step_size: 5000
22+
lr_scheduler_gamma: 0.1
23+
visualization:
24+
history_size: 10
25+
visdom: True
26+
visdom_server: 'localhost'
27+
visdom_port: 8097
28+
visdom_env: 'nerf_pytorch3d'
29+
raysampler:
30+
n_pts_per_ray: 64
31+
n_pts_per_ray_fine: 64
32+
n_rays_per_image: 1024
33+
min_depth: 2.0
34+
max_depth: 6.0
35+
stratified: True
36+
stratified_test: False
37+
chunk_size_test: 6000
38+
implicit_function:
39+
n_harmonic_functions_xyz: 10
40+
n_harmonic_functions_dir: 4
41+
n_hidden_neurons_xyz: 256
42+
n_hidden_neurons_dir: 128
43+
density_noise_std: 0.0
44+
n_layers_xyz: 8

projects/nerf/configs/pt3logo.yaml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
seed: 3
2+
resume: True
3+
stats_print_interval: 10
4+
validation_epoch_interval: 30
5+
checkpoint_epoch_interval: 30
6+
checkpoint_path: 'checkpoints/pt3logo_pt3d.pth'
7+
data:
8+
dataset_name: 'pt3logo'
9+
image_size: [512, 1024] # [height, width]
10+
precache_rays: True
11+
test:
12+
mode: 'export_video'
13+
trajectory_type: 'figure_eight'
14+
up: [0.0, -1.0, 0.0]
15+
scene_center: [0.0, 0.0, 0.0]
16+
n_frames: 100
17+
fps: 20
18+
optimizer:
19+
max_epochs: 100000
20+
lr: 0.0005
21+
lr_scheduler_step_size: 10000
22+
lr_scheduler_gamma: 0.1
23+
visualization:
24+
history_size: 20
25+
visdom: True
26+
visdom_server: 'localhost'
27+
visdom_port: 8097
28+
visdom_env: 'nerf_pytorch3d'
29+
raysampler:
30+
n_pts_per_ray: 64
31+
n_pts_per_ray_fine: 64
32+
n_rays_per_image: 1024
33+
min_depth: 8.0
34+
max_depth: 23.0
35+
stratified: True
36+
stratified_test: False
37+
chunk_size_test: 6000
38+
implicit_function:
39+
n_harmonic_functions_xyz: 10
40+
n_harmonic_functions_dir: 4
41+
n_hidden_neurons_xyz: 256
42+
n_hidden_neurons_dir: 128
43+
density_noise_std: 0.0
44+
n_layers_xyz: 8

projects/nerf/nerf/nerf_renderer.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from typing import Tuple, List, Optional
33

44
import torch
5-
from pytorch3d.renderer import ImplicitRenderer
5+
from pytorch3d.renderer import ImplicitRenderer, ray_bundle_to_ray_points
66
from pytorch3d.renderer.cameras import CamerasBase
7+
from pytorch3d.structures import Pointclouds
8+
from pytorch3d.vis.plotly_vis import plot_scene
9+
from visdom import Visdom
710

811
from .implicit_function import NeuralRadianceField
912
from .raymarcher import EmissionAbsorptionNeRFRaymarcher
@@ -357,3 +360,68 @@ def forward(
357360
)
358361

359362
return out, metrics
363+
364+
365+
def visualize_nerf_outputs(
366+
nerf_out: dict, output_cache: List, viz: Visdom, visdom_env: str
367+
):
368+
"""
369+
Visualizes the outputs of the `RadianceFieldRenderer`.
370+
371+
Args:
372+
nerf_out: An output of the validation rendering pass.
373+
output_cache: A list with outputs of several training render passes.
374+
viz: A visdom connection object.
375+
visdom_env: The name of visdom environment for visualization.
376+
"""
377+
378+
# Show the training images.
379+
ims = torch.stack([o["image"] for o in output_cache])
380+
ims = torch.cat(list(ims), dim=1)
381+
viz.image(
382+
ims.permute(2, 0, 1),
383+
env=visdom_env,
384+
win="images",
385+
opts={"title": "train_images"},
386+
)
387+
388+
# Show the coarse and fine renders together with the ground truth images.
389+
ims_full = torch.cat(
390+
[
391+
nerf_out[imvar][0].permute(2, 0, 1).detach().cpu().clamp(0.0, 1.0)
392+
for imvar in ("rgb_coarse", "rgb_fine", "rgb_gt")
393+
],
394+
dim=2,
395+
)
396+
viz.image(
397+
ims_full,
398+
env=visdom_env,
399+
win="images_full",
400+
opts={"title": "coarse | fine | target"},
401+
)
402+
403+
# Make a 3D plot of training cameras and their emitted rays.
404+
camera_trace = {
405+
f"camera_{ci:03d}": o["camera"].cpu() for ci, o in enumerate(output_cache)
406+
}
407+
ray_pts_trace = {
408+
f"ray_pts_{ci:03d}": Pointclouds(
409+
ray_bundle_to_ray_points(o["coarse_ray_bundle"])
410+
.detach()
411+
.cpu()
412+
.view(1, -1, 3)
413+
)
414+
for ci, o in enumerate(output_cache)
415+
}
416+
plotly_plot = plot_scene(
417+
{
418+
"training_scene": {
419+
**camera_trace,
420+
**ray_pts_trace,
421+
},
422+
},
423+
pointcloud_max_points=5000,
424+
pointcloud_marker_size=1,
425+
camera_scale=0.3,
426+
)
427+
viz.plotlyplot(plotly_plot, env=visdom_env, win="scenes")

0 commit comments

Comments
 (0)