Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fast high resolution marching cubes up to 1024^3. #68

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def run_example(image_pil):
# Upper bound raised to 1024 so the UI can request the block-based
# high-resolution marching cubes path. Only one `maximum` keyword may be
# passed; a repeated keyword argument is a SyntaxError.
mc_resolution = gr.Slider(
    label="Marching Cubes Resolution",
    minimum=32,
    maximum=1024,
    value=256,
    step=32,
)
Expand Down
67 changes: 63 additions & 4 deletions tsr/models/nerf_renderer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Dict
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, reduce
from torchmcubes import marching_cubes

from ..utils import (
BaseModule,
Expand Down Expand Up @@ -38,20 +39,78 @@ def set_chunk_size(self, chunk_size: int):
), "chunk_size must be a non-negative integer (0 for no chunking)."
self.chunk_size = chunk_size

def interpolate_triplane(self, triplane: torch.Tensor, resolution: int):
    """Bilinearly resample every plane of ``triplane`` on a square grid.

    The grid has ``resolution`` evenly spaced coordinates from -1.0 to 1.0
    along both axes and is shared by all three planes.

    Args:
        triplane: Feature planes with a leading dimension of 3 (the same
            sampling grid is expanded to 3 entries to match it).
            Presumably laid out as [3, C, H, W] in xy / xz / yz order --
            confirm against the caller.
        resolution: Number of samples along each grid axis.

    Returns:
        The resampled planes, shaped [3, C, resolution, resolution].
    """
    axis = torch.linspace(-1.0, 1.0, resolution, device=triplane.device)
    first, second = torch.meshgrid(axis, axis, indexing="ij")
    # grid[i, j] = (axis[i], axis[j]); grid_sample reads the last dimension
    # as (width, height) coordinates of the input plane.
    grid = torch.stack((first, second), dim=-1)
    grid = grid.unsqueeze(0).expand(3, -1, -1, -1)
    return F.grid_sample(triplane, grid, mode="bilinear", align_corners=False)

def block_based_marchingcube(self, decoder: torch.nn.Module, triplane: torch.Tensor, resolution: int, threshold, block_resolution = 128) -> "tuple[torch.Tensor, torch.Tensor]":
    """Extract an iso-surface from the triplane one block of voxels at a time.

    The planes are resampled once on a ``resolution + 1`` grid. The volume
    is then walked in blocks of at most ``block_resolution`` cells per
    axis: each block's features are assembled, decoded to density and
    passed to marching cubes, so only one block of samples is alive at a
    time.

    Args:
        decoder: Module mapping per-sample features to a dict that
            contains a ``"density"`` entry.
        triplane: Feature planes accepted by ``interpolate_triplane``.
        resolution: Number of marching-cubes cells along each axis.
        threshold: Iso-level applied to the activated density.
        block_resolution: Maximum number of cells per axis in one block.

    Returns:
        ``(v_pos, indices)``: vertex positions in the normalized
        [-1, 1] cube and triangle indices into ``v_pos``. Blocks are meshed
        independently and concatenated, so vertices on shared block faces
        are not merged.

    Raises:
        ValueError: If ``resolution`` or ``block_resolution`` is below 1.
        NotImplementedError: If ``cfg.feature_reduction`` is neither
            ``"concat"`` nor ``"mean"``.
    """
    if resolution < 1 or block_resolution < 1:
        raise ValueError("resolution and block_resolution must be positive integers.")
    feature_reduction = self.cfg.feature_reduction
    if feature_reduction not in ("concat", "mean"):
        raise NotImplementedError

    # Sample one extra line of density per axis (e.g. 1024 cells -> 1025
    # samples) so index 0 maps to -1.0 and index `resolution` maps to 1.0.
    grid_size = resolution + 1
    voxel_size = 2.0 / resolution
    device = triplane.device
    interpolated = self.interpolate_triplane(triplane, grid_size)
    activation = get_activation(self.cfg.density_activation)

    def block_extent(start: int) -> int:
        # A block needs block_resolution + 1 samples to yield
        # block_resolution cells; the last block is clipped to the grid.
        return min(block_resolution + 1, grid_size - start)

    pos_list = []
    indices_list = []
    vertex_count = 0
    for x in range(0, resolution, block_resolution):
        size_x = block_extent(x)
        for y in range(0, resolution, block_resolution):
            size_y = block_extent(y)
            for z in range(0, resolution, block_resolution):
                size_z = block_extent(z)
                # Broadcast each 2D plane along its missing axis so that
                # all three end up shaped [size_x, size_y, size_z, 1, C].
                xyplane = interpolated[0:1, :, x:x + size_x, y:y + size_y].expand(size_z, -1, -1, -1, -1).permute(3, 4, 0, 1, 2)
                xzplane = interpolated[1:2, :, x:x + size_x, z:z + size_z].expand(size_y, -1, -1, -1, -1).permute(3, 0, 4, 1, 2)
                yzplane = interpolated[2:3, :, y:y + size_y, z:z + size_z].expand(size_x, -1, -1, -1, -1).permute(0, 3, 4, 1, 2)
                sample_count = size_x * size_y * size_z
                features = torch.cat([xyplane, xzplane, yzplane], dim=3).view(sample_count, 3, -1)

                if feature_reduction == "concat":
                    features = features.view(sample_count, -1)
                else:
                    features = reduce(features, "N Np Cp -> N Cp", Np=3, reduction="mean")

                # Keep only the density; drop the samples and the other
                # decoder outputs right away to limit peak memory.
                density = decoder(features)["density"]
                del features
                density = activation(density + self.cfg.density_bias).view(size_x, size_y, size_z).detach()

                try:
                    v_pos, indices = marching_cubes(density, threshold)
                except AttributeError:
                    print("torchmcubes was not compiled with CUDA support, use CPU version instead.")
                    # Same iso-level as the GPU path, then back to the
                    # working device so the results can be combined below.
                    v_pos, indices = marching_cubes(density.cpu(), threshold)
                    v_pos = v_pos.to(device)
                    indices = indices.to(device)

                offset = torch.tensor(
                    [x * voxel_size - 1.0, y * voxel_size - 1.0, z * voxel_size - 1.0],
                    device=device,
                )
                # Reverse the component order returned by marching cubes,
                # then map block-local cell units to the [-1, 1] cube.
                v_pos = v_pos[..., [2, 1, 0]] * voxel_size + offset

                # Re-base this block's triangle indices onto the
                # concatenated vertex array.
                indices_list.append(indices + vertex_count)
                pos_list.append(v_pos)
                vertex_count += v_pos.size(0)

    return torch.cat(pos_list), torch.cat(indices_list)

def query_triplane(
self,
decoder: torch.nn.Module,
positions: torch.Tensor,
triplane: torch.Tensor,
scale_pos = True
) -> Dict[str, torch.Tensor]:
input_shape = positions.shape[:-1]
positions = positions.view(-1, 3)

# positions in (-radius, radius)
# normalized to (-1, 1) for grid sample
positions = scale_tensor(
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
)
if scale_pos:
positions = scale_tensor(
positions, (-self.cfg.radius, self.cfg.radius), (-1, 1)
)

def _query_chunk(x):
indices2D: torch.Tensor = torch.stack(
Expand Down
50 changes: 15 additions & 35 deletions tsr/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from omegaconf import OmegaConf
from PIL import Image

from .models.isosurface import MarchingCubeHelper
from .utils import (
BaseModule,
ImagePreprocessor,
Expand Down Expand Up @@ -160,44 +159,25 @@ def process_output(image: torch.FloatTensor):

return images

def set_marching_cubes_resolution(self, resolution: int):
    """Make sure ``self.isosurface_helper`` matches ``resolution``.

    A new ``MarchingCubeHelper`` is created only when no helper exists yet
    or the existing one has a different resolution; otherwise the current
    helper is left untouched.
    """
    helper = self.isosurface_helper
    needs_rebuild = helper is None or helper.resolution != resolution
    if needs_rebuild:
        self.isosurface_helper = MarchingCubeHelper(resolution)

def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
    """Build one colored triangle mesh per scene code.

    For each scene code the surface is extracted with the renderer's
    block-based marching cubes, vertex colors are queried from the
    triplane at the extracted vertices, and the vertices are rescaled from
    the normalized [-1, 1] cube to ``(-radius, radius)``.

    Args:
        scene_codes: Batch of triplane scene codes; iterated along its
            first dimension and used as the source of the target device.
        resolution: Marching-cubes resolution along each axis.
        threshold: Density iso-level passed to marching cubes.

    Returns:
        A list with one ``trimesh.Trimesh`` per scene code.
    """
    device = scene_codes.device
    # Move the decoder once rather than once per call site per scene.
    decoder = self.decoder.to(device)
    radius = self.renderer.cfg.radius

    meshes = []
    for scene_code in scene_codes:
        with torch.no_grad():
            v_pos, t_pos_idx = self.renderer.block_based_marchingcube(
                decoder,
                scene_code,
                resolution,
                threshold,
            )
            # Vertices are already in the normalized [-1, 1] range, so
            # query_triplane must not rescale them again.
            color = self.renderer.query_triplane(
                decoder,
                v_pos.to(device),
                scene_code,
                scale_pos=False,
            )["color"]
        v_pos = scale_tensor(v_pos, (-1.0, 1.0), (-radius, radius))
        mesh = trimesh.Trimesh(
            vertices=v_pos.cpu().numpy(),
            faces=t_pos_idx.cpu().numpy(),
            vertex_colors=color.cpu().numpy(),
        )
        meshes.append(mesh)
    return meshes