From 7c98d9ac06a4c06333ee4aea65b1df2e56ad31b5 Mon Sep 17 00:00:00 2001 From: StarsTesla Date: Thu, 24 Aug 2023 21:08:18 +0800 Subject: [PATCH] add files --- .gitignore | 0 LICENSE | 0 README.md | 0 activation.py | 18 + encoding.py | 75 +++ freqencoder/__init__.py | 1 + freqencoder/backend.py | 41 ++ freqencoder/freq.py | 77 +++ freqencoder/setup.py | 51 ++ freqencoder/src/bindings.cpp | 8 + freqencoder/src/freqencoder.cu | 129 ++++ freqencoder/src/freqencoder.h | 10 + gridencoder/__init__.py | 1 + gridencoder/backend.py | 40 ++ gridencoder/grid.py | 185 ++++++ gridencoder/setup.py | 50 ++ gridencoder/src/bindings.cpp | 9 + gridencoder/src/gridencoder.cu | 642 +++++++++++++++++++ gridencoder/src/gridencoder.h | 17 + main.py | 287 +++++++++ nerf/blender.py | 162 +++++ nerf/clip.py | 38 ++ nerf/gui.py | 469 ++++++++++++++ nerf/llff.py | 573 +++++++++++++++++ nerf/network_grid.py | 193 ++++++ nerf/renderer.py | 688 ++++++++++++++++++++ nerf/sd.py | 220 +++++++ nerf/utils.py | 1087 ++++++++++++++++++++++++++++++++ optimizer.py | 325 ++++++++++ raymarching/__init__.py | 1 + raymarching/backend.py | 40 ++ raymarching/raymarching.py | 385 +++++++++++ raymarching/setup.py | 62 ++ raymarching/src/bindings.cpp | 19 + raymarching/src/raymarching.cu | 914 +++++++++++++++++++++++++++ raymarching/src/raymarching.h | 18 + requirements.txt | 56 ++ scripts/colmap2nerf.py | 331 ++++++++++ scripts/install_ext.sh | 4 + scripts/run.sh | 5 + shencoder/__init__.py | 1 + shencoder/backend.py | 40 ++ shencoder/setup.py | 50 ++ shencoder/sphere_harmonics.py | 87 +++ shencoder/src/bindings.cpp | 8 + shencoder/src/shencoder.cu | 439 +++++++++++++ shencoder/src/shencoder.h | 10 + 47 files changed, 7866 insertions(+) mode change 100644 => 100755 .gitignore mode change 100644 => 100755 LICENSE mode change 100644 => 100755 README.md create mode 100755 activation.py create mode 100755 encoding.py create mode 100755 freqencoder/__init__.py create mode 100755 freqencoder/backend.py create mode 100755 freqencoder/freq.py create mode 100755 freqencoder/setup.py create mode 100755 freqencoder/src/bindings.cpp create mode 100755 freqencoder/src/freqencoder.cu create mode 100755 freqencoder/src/freqencoder.h create mode 100755 gridencoder/__init__.py create mode 100755 gridencoder/backend.py create mode 100755 gridencoder/grid.py create mode 100755 gridencoder/setup.py create mode 100755 gridencoder/src/bindings.cpp create mode 100755 gridencoder/src/gridencoder.cu create mode 100755 gridencoder/src/gridencoder.h create mode 100755 main.py create mode 100755 nerf/blender.py create mode 100755 nerf/clip.py create mode 100755 nerf/gui.py create mode 100755 nerf/llff.py create mode 100755 nerf/network_grid.py create mode 100755 nerf/renderer.py create mode 100755 nerf/sd.py create mode 100755 nerf/utils.py create mode 100755 optimizer.py create mode 100755 raymarching/__init__.py create mode 100755 raymarching/backend.py create mode 100755 raymarching/raymarching.py create mode 100755 raymarching/setup.py create mode 100755 raymarching/src/bindings.cpp create mode 100755 raymarching/src/raymarching.cu create mode 100755 raymarching/src/raymarching.h create mode 100755 requirements.txt create mode 100755 scripts/colmap2nerf.py create mode 100755 scripts/install_ext.sh create mode 100755 scripts/run.sh create mode 100755 shencoder/__init__.py create mode 100755 shencoder/backend.py create mode 100755 shencoder/setup.py create mode 100755 shencoder/sphere_harmonics.py create mode 100755 shencoder/src/bindings.cpp create mode 100755 
shencoder/src/shencoder.cu create mode 100755 shencoder/src/shencoder.h diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/activation.py b/activation.py new file mode 100755 index 0000000..c8edfd3 --- /dev/null +++ b/activation.py @@ -0,0 +1,18 @@ +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +class _trunc_exp(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float) + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): + x = ctx.saved_tensors[0] + return g * torch.exp(x.clamp(max=15)) + +trunc_exp = _trunc_exp.apply \ No newline at end of file diff --git a/encoding.py b/encoding.py new file mode 100755 index 0000000..531361b --- /dev/null +++ b/encoding.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FreqEncoder_torch(nn.Module): + def __init__(self, input_dim, max_freq_log2, N_freqs, + log_sampling=True, include_input=True, + periodic_fns=(torch.sin, torch.cos)): + + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.output_dim = 0 + if self.include_input: + self.output_dim += self.input_dim + + self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input, **kwargs): + + out = [] + if self.include_input: + out.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + + out = torch.cat(out, dim=-1) + + return out + +def get_encoder(encoding, input_dim=3, + multires=6, + degree=4, + num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, + **kwargs): + + if encoding == 'None': + return lambda x, **kwargs: x, input_dim + + elif encoding == 'frequency_torch': + encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) + + elif encoding == 'frequency': # CUDA implementation, faster than torch. 
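For reference, both the torch and CUDA 'frequency' paths compute the same encoding: each coordinate x is mapped to [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(degree-1) x), cos(2^(degree-1) x)]. A minimal pure-PyTorch sketch of that mapping (illustrative only; the CUDA kernel below fuses it into a single pass):

import torch

def freq_encode_reference(x: torch.Tensor, degree: int = 4) -> torch.Tensor:
    # [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(degree-1) x), cos(2^(degree-1) x)]
    out = [x]
    for f in range(degree):
        out.append(torch.sin((2.0 ** f) * x))
        out.append(torch.cos((2.0 ** f) * x))
    return torch.cat(out, dim=-1)  # [..., input_dim * (1 + 2 * degree)]

x = torch.rand(8, 3)
print(freq_encode_reference(x, degree=4).shape)  # torch.Size([8, 27]) = 3 + 3*2*4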
+ from freqencoder import FreqEncoder + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoding == 'sphere_harmonics': + from shencoder import SHEncoder + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoding == 'hashgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners) + + elif encoding == 'tiledgrid': + from gridencoder import GridEncoder + encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners) + + else: + raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]') + + return encoder, encoder.output_dim \ No newline at end of file diff --git a/freqencoder/__init__.py b/freqencoder/__init__.py new file mode 100755 index 0000000..69ec49c --- /dev/null +++ b/freqencoder/__init__.py @@ -0,0 +1 @@ +from .freq import FreqEncoder \ No newline at end of file diff --git a/freqencoder/backend.py b/freqencoder/backend.py new file mode 100755 index 0000000..3bd9131 --- /dev/null +++ b/freqencoder/backend.py @@ -0,0 +1,41 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
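One practical note on this JIT path: torch.utils.cpp_extension.load compiles the extension on first import and caches the build, so only the first run pays the compilation cost. A hedged sketch of two standard PyTorch knobs (not set by this repo) that can help when the build misbehaves:

import os
from torch.utils.cpp_extension import load

os.environ.setdefault('TORCH_EXTENSIONS_DIR', '/tmp/torch_extensions')  # relocate the build cache
_backend = load(name='_freqencoder',
                sources=['src/freqencoder.cu', 'src/bindings.cpp'],  # illustrative paths
                verbose=True)  # print compiler invocations when debugging build failures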
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_freqencoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/freqencoder/freq.py b/freqencoder/freq.py new file mode 100755 index 0000000..5cba1e6 --- /dev/null +++ b/freqencoder/freq.py @@ -0,0 +1,77 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _freqencoder as _backend +except ImportError: + from .backend import _backend + + +class _freq_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, output_dim): + # inputs: [B, input_dim], float + # RETURN: [B, F], float + + if not inputs.is_cuda: inputs = inputs.cuda() + inputs = inputs.contiguous() + + B, input_dim = inputs.shape # batch size, coord dim + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) + + ctx.save_for_backward(inputs, outputs) + ctx.dims = [B, input_dim, degree, output_dim] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + grad = grad.contiguous() + inputs, outputs = ctx.saved_tensors + B, input_dim, degree, output_dim = ctx.dims + + grad_inputs = torch.zeros_like(inputs) + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) + + return grad_inputs, None, None + + +freq_encode = _freq_encoder.apply + + +class FreqEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim + self.degree = degree + self.output_dim = input_dim + input_dim * 2 * degree + + def __repr__(self): + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" + + def forward(self, inputs, **kwargs): + # inputs: [..., input_dim] + # return: [..., ] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = freq_encode(inputs, self.degree, self.output_dim) + + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/freqencoder/setup.py b/freqencoder/setup.py new file mode 100755 index 0000000..3eb4af7 --- /dev/null +++ b/freqencoder/setup.py @@ -0,0 +1,51 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', + '-use_fast_math' +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual 
Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='freqencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_freqencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'freqencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/freqencoder/src/bindings.cpp b/freqencoder/src/bindings.cpp new file mode 100755 index 0000000..bb5f285 --- /dev/null +++ b/freqencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include + +#include "freqencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)"); + m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)"); +} \ No newline at end of file diff --git a/freqencoder/src/freqencoder.cu b/freqencoder/src/freqencoder.cu new file mode 100755 index 0000000..072da74 --- /dev/null +++ b/freqencoder/src/freqencoder.cu @@ -0,0 +1,129 @@ +#include + +#include +#include +#include + +#include +#include + +#include +#include + +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + +inline constexpr __device__ float PI() { return 3.141592653589793f; } + +template +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +// inputs: [B, D] +// outputs: [B, C], C = D + D * deg * 2 +__global__ void kernel_freq( + const float * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * outputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * C) return; + + // get index + const uint32_t b = t / C; + const uint32_t c = t - b * C; // t % C; + + // locate + inputs += b * D; + outputs += t; + + // write self + if (c < D) { + outputs[0] = inputs[c]; + // write freq + } else { + const uint32_t col = c / D - 1; + const uint32_t d = c % D; + const uint32_t freq = col / 2; + const float phase_shift = (col % 2) * (PI() / 2); + outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift); + } +} + +// grad: [B, C], C = D + D * deg * 2 +// outputs: [B, C] +// grad_inputs: [B, D] +__global__ void kernel_freq_backward( + const float * __restrict__ grad, + const float * __restrict__ outputs, + uint32_t B, uint32_t D, uint32_t deg, uint32_t C, + float * grad_inputs +) { + // parallel on per-element + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; // t % D; + + // locate + grad += b * C; + outputs += b * C; + grad_inputs += t; + + // 
register + float result = grad[d]; + grad += D; + outputs += D; + + for (uint32_t f = 0; f < deg; f++) { + result += scalbnf(1.0f, f) * (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]); + grad += 2 * D; + outputs += 2 * D; + } + + // write + grad_inputs[0] = result; +} + + +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq<<>>(inputs.data_ptr(), B, D, deg, C, outputs.data_ptr()); +} + + +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(outputs); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(outputs); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(outputs); + CHECK_IS_FLOATING(grad_inputs); + + static constexpr uint32_t N_THREADS = 128; + + kernel_freq_backward<<>>(grad.data_ptr(), outputs.data_ptr(), B, D, deg, C, grad_inputs.data_ptr()); +} \ No newline at end of file diff --git a/freqencoder/src/freqencoder.h b/freqencoder/src/freqencoder.h new file mode 100755 index 0000000..34f28c7 --- /dev/null +++ b/freqencoder/src/freqencoder.h @@ -0,0 +1,10 @@ +# pragma once + +#include +#include + +// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) +void freq_encode_forward(at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor outputs); + +// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) +void freq_encode_backward(at::Tensor grad, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t deg, const uint32_t C, at::Tensor grad_inputs); \ No newline at end of file diff --git a/gridencoder/__init__.py b/gridencoder/__init__.py new file mode 100755 index 0000000..f1476ce --- /dev/null +++ b/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/gridencoder/backend.py b/gridencoder/backend.py new file mode 100755 index 0000000..d99acb1 --- /dev/null +++ b/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
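kernel_freq_backward above leans on the identities d/dx sin(2^f x) = 2^f cos(2^f x) and d/dx cos(2^f x) = -2^f sin(2^f x), which is why it can reuse the saved forward outputs instead of recomputing sin/cos. A quick sanity check of that gradient scheme against autograd, using the freq_encode_reference sketched earlier (double precision keeps gradcheck happy):

import torch

x = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(lambda t: freq_encode_reference(t, degree=3), (x,))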
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/gridencoder/grid.py b/gridencoder/grid.py new file mode 100755 index 0000000..32b8bea --- /dev/null +++ b/gridencoder/grid.py @@ -0,0 +1,185 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
+        if torch.is_autocast_enabled() and C % 2 == 0:
+            embeddings = embeddings.to(torch.half)
+
+        # L first, optimize cache for cuda kernel, but needs an extra permute later
+        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
+
+        if calc_grad_inputs:
+            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype)
+        else:
+            dy_dx = None
+
+        _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation)
+
+        # permute back to [B, L * C]
+        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
+
+        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
+        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation]
+        ctx.align_corners = align_corners
+
+        return outputs
+
+    @staticmethod
+    #@once_differentiable
+    @custom_bwd
+    def backward(ctx, grad):
+
+        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
+        B, D, C, L, S, H, gridtype, interpolation = ctx.dims
+        align_corners = ctx.align_corners
+
+        # grad: [B, L * C] --> [L, B, C]
+        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
+
+        grad_embeddings = torch.zeros_like(embeddings)
+
+        if dy_dx is not None:
+            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
+        else:
+            grad_inputs = None
+
+        _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation)
+
+        if dy_dx is not None:
+            grad_inputs = grad_inputs.to(inputs.dtype)
+
+        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None
+
+
+grid_encode = _grid_encode.apply
+
+
+class GridEncoder(nn.Module):
+    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype='hash', align_corners=False, interpolation='linear'):
+        super().__init__()
+
+        # the finest resolution desired at the last level; if provided, overrides per_level_scale
+        if desired_resolution is not None:
+            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
+
+        self.input_dim = input_dim # coord dims, 2 or 3
+        self.num_levels = num_levels # num levels, each level multiplies resolution by per_level_scale
+        self.level_dim = level_dim # encode channels per level
+        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
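To make the geometry concrete: with desired_resolution given, per_level_scale = 2^(log2(desired / base) / (num_levels - 1)), so level i has resolution ceil(base * scale^i) and the levels grow geometrically from base_resolution to desired_resolution. A quick check with the get_encoder defaults (base 16, desired 2048, 16 levels):

import numpy as np

base, desired, L = 16, 2048, 16
scale = np.exp2(np.log2(desired / base) / (L - 1))
print(round(float(scale), 4))  # ~1.3819
print([int(np.ceil(base * scale ** i)) for i in range(L)])  # 16, 23, 31, ..., 2048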
+ self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution if align_corners else resolution + 1) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. 
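As the guard below makes explicit, this method writes directly into embeddings.grad, so it must run between loss.backward() and optimizer.step(). An illustrative placement inside a training step (model, criterion, rays, and optimizer are hypothetical names):

loss = criterion(model(rays), target)                     # hypothetical forward + loss
loss.backward()
model.encoder.grad_total_variation(weight=1e-7, bound=1)  # accumulates TV grad in-place
optimizer.step()
optimizer.zero_grad()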
+ + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) \ No newline at end of file diff --git a/gridencoder/setup.py b/gridencoder/setup.py new file mode 100755 index 0000000..714bf1c --- /dev/null +++ b/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/gridencoder/src/bindings.cpp b/gridencoder/src/bindings.cpp new file mode 100755 index 0000000..93dea94 --- /dev/null +++ b/gridencoder/src/bindings.cpp @@ -0,0 +1,9 @@ +#include + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.cu b/gridencoder/src/gridencoder.cu new file mode 100755 index 0000000..fdd49cb --- /dev/null +++ b/gridencoder/src/gridencoder.cu @@ -0,0 +1,642 @@ +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it. + //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { + return min(max(v, lo), hi); +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? 
resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate (always use float for precision!) + float pos[D]; + float pos_deriv[D] = {1.0f}; // linear deriv is default to 1 + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + return; // grad is init as 0, so we simply return. 
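+            // Past this range check, the kernel mirrors the forward pass: it
+            // recomputes the 2^D corner weights, then atomically scatters
+            // w * dL/dy into each corner's embedding row. Atomics are needed
+            // because different samples (and hash collisions) can target the
+            // same row concurrently.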
+ } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) 
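Putting fast_hash, get_grid_index and the corner loop together: each sample scales its position to the level's resolution, visits the 2^D surrounding corners, hashes each corner into the table, and blends the fetched features with a product of per-axis linear weights. A NumPy sketch of one 3D level (illustrative: it always hashes, whereas the kernel falls back to dense indexing whenever the level fits in the table, and smoothstep is omitted):

import numpy as np

PRIMES = np.array([1, 2654435761, 805459861], dtype=np.uint32)  # as in fast_hash

def hashgrid_lookup(x, table, resolution):
    # x: (3,) in [0, 1]; table: (hashmap_size, C)
    pos = x * (resolution - 1.0) + 0.5                  # align_corners=False offset
    p0 = np.floor(pos).astype(np.uint32)
    frac = pos - p0
    out = np.zeros(table.shape[1])
    for idx in range(8):                                # 2^D corners
        bits = np.array([(idx >> d) & 1 for d in range(3)])
        corner = (p0 + bits).astype(np.uint32)
        w = np.prod(np.where(bits, frac, 1.0 - frac))   # per-axis linear weights
        h = np.bitwise_xor.reduce(corner * PRIMES) % np.uint32(len(table))
        out += w * table[h]
    return out

table = np.random.uniform(-1e-4, 1e-4, size=(2**19, 2))
print(hashgrid_lookup(np.array([0.3, 0.7, 0.5]), table, resolution=128))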
+// H: base resolution
+// dy_dx: [B, L * D * C]
+template <typename scalar_t>
+void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+    switch (D) {
+        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break;
+        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
+    }
+}
+
+template <typename scalar_t, uint32_t D>
+void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+    static constexpr uint32_t N_THREAD = 256;
+    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
+    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
+    switch (C) {
+        case 1:
+            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+            if (dy_dx) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        case 2:
+            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+            if (dy_dx) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        case 4:
+            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+            if (dy_dx) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        case 8:
+            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp);
+            if (dy_dx) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
+            break;
+        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+    }
+}
+
+
+// grad: [L, B, C], float
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// grad_embeddings: [sO, C]
+// H: base resolution
+template <typename scalar_t>
+void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+    switch (D) {
+        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx,
            grad_inputs, gridtype, align_corners, interp); break;
+        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break;
+        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
+    }
+}
+
+
+
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+    CHECK_CUDA(inputs);
+    CHECK_CUDA(embeddings);
+    CHECK_CUDA(offsets);
+    CHECK_CUDA(outputs);
+    // CHECK_CUDA(dy_dx);
+
+    CHECK_CONTIGUOUS(inputs);
+    CHECK_CONTIGUOUS(embeddings);
+    CHECK_CONTIGUOUS(offsets);
+    CHECK_CONTIGUOUS(outputs);
+    // CHECK_CONTIGUOUS(dy_dx);
+
+    CHECK_IS_FLOATING(inputs);
+    CHECK_IS_FLOATING(embeddings);
+    CHECK_IS_INT(offsets);
+    CHECK_IS_FLOATING(outputs);
+    // CHECK_IS_FLOATING(dy_dx);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    embeddings.scalar_type(), "grid_encode_forward", ([&] {
+        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, gridtype, align_corners, interp);
+    }));
+}
+
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional<at::Tensor> dy_dx, at::optional<at::Tensor> grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
+    CHECK_CUDA(grad);
+    CHECK_CUDA(inputs);
+    CHECK_CUDA(embeddings);
+    CHECK_CUDA(offsets);
+    CHECK_CUDA(grad_embeddings);
+    // CHECK_CUDA(dy_dx);
+    // CHECK_CUDA(grad_inputs);
+
+    CHECK_CONTIGUOUS(grad);
+    CHECK_CONTIGUOUS(inputs);
+    CHECK_CONTIGUOUS(embeddings);
+    CHECK_CONTIGUOUS(offsets);
+    CHECK_CONTIGUOUS(grad_embeddings);
+    // CHECK_CONTIGUOUS(dy_dx);
+    // CHECK_CONTIGUOUS(grad_inputs);
+
+    CHECK_IS_FLOATING(grad);
+    CHECK_IS_FLOATING(inputs);
+    CHECK_IS_FLOATING(embeddings);
+    CHECK_IS_INT(offsets);
+    CHECK_IS_FLOATING(grad_embeddings);
+    // CHECK_IS_FLOATING(dy_dx);
+    // CHECK_IS_FLOATING(grad_inputs);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    grad.scalar_type(), "grid_encode_backward", ([&] {
+        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>() : nullptr, grad_inputs.has_value() ?
grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); + +} + + +template +__global__ void kernel_grad_tv( + const scalar_t * __restrict__ inputs, + const scalar_t * __restrict__ grid, + scalar_t * __restrict__ grad, + const int * __restrict__ offsets, + const float weight, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + inputs += b * D; + grid += (uint32_t)offsets[level] * C; + grad += (uint32_t)offsets[level] * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + + // if input out of bound, do nothing + if (flag_oob) return; + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; // [0, resolution] + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + // pos[d] -= (float)pos_grid[d]; // not used + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // total variation on pos_grid + scalar_t results[C] = {0}; // temp results in register + scalar_t idelta[C] = {0}; + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + scalar_t w = weight / (2 * D); + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + uint32_t cur_d = pos_grid[d]; + scalar_t grad_val; + + // right side + if (cur_d < resolution) { + pos_grid[d] = cur_d + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_right + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // left side + if (cur_d > 0) { + pos_grid[d] = cur_d - 1; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_left + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // reset + pos_grid[d] = cur_d; + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // index may collide, so use atomic! 
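+        // results[ch] holds the summed neighbor differences and idelta[ch]
+        // their summed squares, so w * results[ch] * rsqrtf(idelta[ch] + 1e-9f)
+        // below is (up to the epsilon) the gradient of the square root of the
+        // summed squared differences: an L2-normalized total-variation step
+        // rather than the gradient of the raw squared TV.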
+        atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
+    }
+
+}
+
+
+template <typename scalar_t, uint32_t D>
+void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+    static constexpr uint32_t N_THREAD = 512;
+    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
+    switch (C) {
+        case 1: kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+        case 2: kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+        case 4: kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+        case 8: kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break;
+        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
+    }
+}
+
+
+template <typename scalar_t>
+void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+    switch (D) {
+        case 2: kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+        case 3: kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+        case 4: kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+        case 5: kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break;
+        default: throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
+    }
+}
+
+
+void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) {
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    embeddings.scalar_type(), "grad_total_variation", ([&] {
+        grad_total_variation_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), embeddings.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(), weight, B, D, C, L, S, H, gridtype, align_corners);
+    }));
+}
\ No newline at end of file
diff --git a/gridencoder/src/gridencoder.h b/gridencoder/src/gridencoder.h
new file mode 100755
index 0000000..1b38575
--- /dev/null
+++ b/gridencoder/src/gridencoder.h
@@ -0,0 +1,17 @@
+#ifndef _HASH_ENCODE_H
+#define _HASH_ENCODE_H
+
+#include <stdint.h>
+#include <torch/torch.h>
+
+// inputs: [B, D], float, in [0, 1]
+// embeddings: [sO, C], float
+// offsets: [L + 1], uint32_t
+// outputs: [B, L * C], float
+// H: base resolution
+void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional<at::Tensor> dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp);
+void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S,
const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/main.py b/main.py new file mode 100755 index 0000000..3d10b6d --- /dev/null +++ b/main.py @@ -0,0 +1,287 @@ +import torch +import argparse + +from nerf.provider import NeRFDataset +from nerf.utils import * + +from nerf.gui import NeRFGUI +from nerf.blender import BlenderDataset +from nerf.llff import LLFFDataset +# from nerf.llff_pre import LLFFDataset +# torch.autograd.set_detect_anomaly(True) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--text', default=None, help="text prompt") + parser.add_argument('--text_bg', default=None, help="background prompt") + parser.add_argument('--negative', default='', type=str, + help="negative text prompt") + parser.add_argument('-O', action='store_true', + help="equals --fp16 --cuda_ray --dir_text") + parser.add_argument('-O2', action='store_true', + help="equals --backbone vanilla --dir_text") + parser.add_argument('--test', action='store_true', help="test mode") + parser.add_argument('--save_mesh', action='store_true', + help="export an obj mesh with texture") + parser.add_argument('--eval_interval', type=int, default=1, + help="evaluate on the valid set every interval epochs") + parser.add_argument('--workspace', type=str, default='workspace') + parser.add_argument('--guidance', type=str, default='stable-diffusion', + help='choose from [stable-diffusion, clip]') + parser.add_argument('--seed', type=int, default=0) + + # training options + parser.add_argument('--iters', type=int, default=10000, + help="training iters") + parser.add_argument('--lr', type=float, default=1e-3, + help="max learning rate") + parser.add_argument('--warm_iters', type=int, + default=500, help="training iters") + parser.add_argument('--min_lr', type=float, default=1e-4, + help="minimal learning rate") + parser.add_argument('--ckpt', type=str, default='latest') + parser.add_argument('--cuda_ray', action='store_true', + help="use CUDA raymarching instead of pytorch") + parser.add_argument('--max_steps', type=int, default=512, + help="max num steps sampled per ray (only valid when using --cuda_ray)") + parser.add_argument('--num_steps', type=int, default=64, + help="num steps sampled per ray (only valid when not using --cuda_ray)") + parser.add_argument('--upsample_steps', type=int, default=32, + help="num steps up-sampled per ray (only valid when not using --cuda_ray)") + parser.add_argument('--update_extra_interval', type=int, default=16, + help="iter interval to update extra status (only valid when using --cuda_ray)") + parser.add_argument('--max_ray_batch', type=int, default=4096, + help="batch size of rays at inference to avoid OOM (only valid when not using --cuda_ray)") + parser.add_argument('--albedo', action='store_true', + help="only use albedo shading to train, overrides --albedo_iters") + parser.add_argument('--albedo_iters', type=int, default=1000, + help="training iters that only use albedo shading") + parser.add_argument('--uniform_sphere_rate', type=float, default=0.5, + help="likelihood of sampling camera location uniformly on 
the sphere surface area") + # model options + parser.add_argument('--bg_radius', type=float, default=0, + help="if positive, use a background model at sphere(bg_radius)") + parser.add_argument('--density_thresh', type=float, default=10, + help="threshold for density grid to be occupied") + parser.add_argument('--blob_density', type=float, default=10, + help="max (center) density for the gaussian density blob") + parser.add_argument('--blob_radius', type=float, default=0.3, + help="control the radius for the gaussian density blob") + # network backbone + parser.add_argument('--fp16', action='store_true', + help="use amp mixed precision training") + parser.add_argument('--backbone', type=str, default='grid', + choices=['grid', 'vanilla'], help="nerf backbone") + parser.add_argument('--optim', type=str, default='adan', + choices=['adan', 'adam', 'adamw'], help="optimizer") + parser.add_argument('--sd_version', type=str, default='2.0', + choices=['1.5', '2.0'], help="stable diffusion version") + parser.add_argument('--hf_key', type=str, default=None, + help="hugging face Stable diffusion model key") + # rendering resolution in training, decrease this if CUDA OOM. + parser.add_argument('--w', type=int, default=400, + help="render width for NeRF in training") + parser.add_argument('--h', type=int, default=400, + help="render height for NeRF in training") + parser.add_argument('--jitter_pose', action='store_true', + help="add jitters to the randomly sampled camera poses") + + # dataset options + + parser.add_argument('--bound', type=float, default=1.3, + help="assume the scene is bounded in box(-bound, bound)") + parser.add_argument('--dt_gamma', type=float, default=0, + help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)") + parser.add_argument('--min_near', type=float, default=0.1, + help="minimum near distance for camera") + parser.add_argument('--radius_range', type=float, nargs='*', + default=[1.0, 1.5], help="training camera radius range") + parser.add_argument('--fovy_range', type=float, nargs='*', + default=[40, 70], help="training camera fovy range") + parser.add_argument('--dir_text', action='store_true', + help="direction-encode the text prompt, by appending front/side/back/overhead view") + parser.add_argument('--suppress_face', action='store_true', + help="also use negative dir text prompt.") + parser.add_argument('--angle_overhead', type=float, default=30, + help="[0, angle_overhead] is the overhead region") + parser.add_argument('--angle_front', type=float, default=60, + help="[0, angle_front] is the front region, [180, 180+angle_front] the back region, otherwise the side region.") + + # GUI options + parser.add_argument('--gui', action='store_true', help="start a GUI") + parser.add_argument('--W', type=int, default=1920, help="GUI width") + parser.add_argument('--H', type=int, default=1080, help="GUI height") + parser.add_argument('--radius', type=float, default=3, + help="default GUI camera radius from center") + parser.add_argument('--fovy', type=float, default=60, + help="default GUI camera fovy") + parser.add_argument('--light_theta', type=float, default=60, + help="default GUI light direction in [0, 180], corresponding to elevation [90, -90]") + parser.add_argument('--light_phi', type=float, default=0, + help="default GUI light direction in [0, 360), azimuth") + parser.add_argument('--max_spp', type=int, default=1, + help="GUI rendering max sample per pixel") + + # for scene + 
parser.add_argument('--img_wh', nargs="+", type=int, default=[504, 378], # [252, 189] + help='resolution (img_w, img_h) of the image') + parser.add_argument('--data_dir', type=str, default='../') + parser.add_argument('--exp_name', type=str, default='flower') + parser.add_argument('--data_type', type=str, default='llff') + parser.add_argument('--spheric_poses', action='store_true') + + parser.add_argument('--pretrained', action='store_true') + + opt = parser.parse_args() + + if opt.O: + opt.fp16 = True + opt.dir_text = False + opt.cuda_ray = True + + elif opt.O2: + # only use fp16 if not evaluating normals (else it leads to NaNs in training...) + if opt.albedo: + opt.fp16 = True + opt.dir_text = False + opt.backbone = 'vanilla' + + if opt.albedo: + opt.albedo_iters = opt.iters + + from nerf.network_grid import NeRFNetwork + + print(opt) + + seed_everything(opt.seed) + + model = NeRFNetwork(opt) + + print(model) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if opt.test: + guidance = None # no need to load guidance model at test + clip_guidance = None + trainer = Trainer('df', opt, model, guidance, clip_guidance, device=device, + workspace=opt.workspace, fp16=opt.fp16, use_checkpoint=opt.ckpt) + + if opt.gui: + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + # test_loader = NeRFDataset(opt, device=device, type='test', H=opt.H, W=opt.W, size=100).dataloader() + + dargs = { + 'root_dir': os.path.join(opt.data_dir, opt.exp_name), + 'img_wh': tuple(opt.img_wh)} + + print("test data.......") + if opt.data_type == 'llff': + dargs['spheric_poses'] = opt.spheric_poses + dargs['val_num'] = 1 + # print("test data .......llff") + test_dataset = LLFFDataset( + device=device, split='test', **dargs) # modified + else: + test_dataset = BlenderDataset( + split='test', device=device, **dargs) # modified + + test_loader = DataLoader(test_dataset, # modified + shuffle=False, + num_workers=0, + batch_size=1) + # print("test loader.......................") + trainer.test(test_loader, write_video=True) + + if opt.save_mesh: + trainer.save_mesh(resolution=256) + + else: + + dargs = { + 'root_dir': os.path.join(opt.data_dir, opt.exp_name), + 'img_wh': tuple(opt.img_wh)} + + if opt.data_type == 'llff': + dargs['spheric_poses'] = opt.spheric_poses + dargs['val_num'] = 1 + train_dataset = LLFFDataset( + device=device, split='train', **dargs) + else: + train_dataset = BlenderDataset( + split='train', device=device, **dargs) + train_loader = DataLoader(train_dataset, + shuffle=True, + num_workers=0, + batch_size=1, + pin_memory=False) + + if opt.optim == 'adan': + from optimizer import Adan + # Adan usually requires a larger LR + + def optimizer(model): return Adan(model.get_params( + 5 * opt.lr), eps=1e-8, weight_decay=2e-5, max_grad_norm=5.0, foreach=False) + else: # adam + def optimizer(model): return torch.optim.Adam( + model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15) + + if opt.backbone == 'vanilla': + def warm_up_with_cosine_lr(iter): return iter / opt.warm_iters if iter <= opt.warm_iters \ + else max(0.5 * (math.cos((iter - opt.warm_iters) / (opt.iters - opt.warm_iters) * math.pi) + 1), + opt.min_lr / opt.lr) + + def scheduler(optimizer): return optim.lr_scheduler.LambdaLR( + optimizer, warm_up_with_cosine_lr) + else: + # scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 1) # fixed + def scheduler(optimizer): return optim.lr_scheduler.LambdaLR( + optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)) + if opt.pretrained: + if opt.guidance ==
'stable-diffusion': + from nerf.sd import StableDiffusion + from nerf.clip import CLIP + guidance = StableDiffusion(device, opt.sd_version, opt.hf_key) + clip_guidance = CLIP(device) + + elif opt.guidance == 'clip': + from nerf.clip import CLIP + guidance = CLIP(device) + else: + raise NotImplementedError( + f'--guidance {opt.guidance} is not implemented.') + else: + guidance = None + clip_guidance = None + + trainer = Trainer('df', opt, model, guidance, clip_guidance, device=device, workspace=opt.workspace, optimizer=optimizer, ema_decay=None, + fp16=opt.fp16, lr_scheduler=scheduler, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, scheduler_update_every_step=True, pretrained=opt.pretrained) + + if opt.gui: + trainer.train_loader = train_loader # attach dataloader to trainer + + gui = NeRFGUI(opt, trainer) + gui.render() + + else: + # valid_loader = NeRFDataset(opt, device=device, type='val', H=opt.H, W=opt.W, size=5).dataloader() + + if opt.data_type == 'llff': + dargs['spheric_poses'] = opt.spheric_poses + dargs['val_num'] = 1 + val_dataset = LLFFDataset( + device=device, split='val', **dargs) # modified + else: + val_dataset = BlenderDataset( + split='val', device=device, **dargs) # modified + valid_loader = DataLoader(val_dataset, # modified + shuffle=False, + num_workers=0, + batch_size=1, + pin_memory=False) + max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) + trainer.train(train_loader, valid_loader, max_epoch) diff --git a/nerf/blender.py b/nerf/blender.py new file mode 100755 index 0000000..fad9a96 --- /dev/null +++ b/nerf/blender.py @@ -0,0 +1,162 @@ +import os +import json + +import torch +import torch.nn as nn +from torch.utils.data import Dataset +from torchvision import transforms as T + +import numpy as np +from PIL import Image + +import glob + +from kornia import create_meshgrid + +# This code is borrowed from https://github.com/kwea123/nerf_pl/blob/master/datasets/blender.py +# I modified the returned batch to contain the whole-resolution image + + +def get_ray_directions(H, W, focal): + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + i, j = grid.unbind(-1) + directions = torch.stack( + [(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) + return directions # (H, W, 3) + + +def get_rays(directions, c2w): + rays_d = directions @ c2w[:, :3].T + rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) + rays_o = c2w[:, 3].expand(rays_d.shape) # H, W, 3 + rays_d = rays_d.view(-1, 3) # H*W, 3 + rays_o = rays_o.view(-1, 3) + + return rays_o, rays_d + + +def get_ndc_rays(H, W, focal, near, rays_o, rays_d): + """ + Transform rays from world coordinate to NDC. + NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. + For detailed derivation, please see: + http://www.songho.ca/opengl/gl_projectionmatrix.html + https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf + + In practice, use NDC "if and only if" the scene is unbounded (has a large depth).
+ See https://github.com/bmild/nerf/issues/18 + + Inputs: + H, W, focal: image height, width and focal length + near: (N_rays) or float, the depths of the near plane + rays_o: (N_rays, 3), the origin of the rays in world coordinate + rays_d: (N_rays, 3), the direction of the rays in world coordinate + + Outputs: + rays_o: (N_rays, 3), the origin of the rays in NDC + rays_d: (N_rays, 3), the direction of the rays in NDC + """ + # Shift ray origins to near plane + t = -(near + rays_o[..., 2]) / rays_d[..., 2] + rays_o = rays_o + t[..., None] * rays_d + + # Store some intermediate homogeneous results + ox_oz = rays_o[..., 0] / rays_o[..., 2] + oy_oz = rays_o[..., 1] / rays_o[..., 2] + + # Projection + o0 = -1./(W/(2.*focal)) * ox_oz + o1 = -1./(H/(2.*focal)) * oy_oz + o2 = 1. + 2. * near / rays_o[..., 2] + + d0 = -1./(W/(2.*focal)) * (rays_d[..., 0]/rays_d[..., 2] - ox_oz) + d1 = -1./(H/(2.*focal)) * (rays_d[..., 1]/rays_d[..., 2] - oy_oz) + d2 = 1 - o2 + + rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) + rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) + + return rays_o, rays_d + + +def normalize(v): + """Normalize a vector.""" + return v/np.linalg.norm(v) + + +class BlenderDataset(Dataset): + def __init__(self, device, root_dir, split='train', img_wh=(80, 80)): + self.root_dir = root_dir + self.split = split + # assert img_wh[0] == img_wh[1] + self.img_wh = img_wh + self.define_transform() + + self.read_meta() + self.white_back = True + self.device = device + + def define_transform(self): + self.transform = T.ToTensor() + + def read_meta(self): + with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: + self.meta = json.load(f) + + w, h = self.img_wh + self.focal = 0.5 * 800 / np.tan(0.5*self.meta['camera_angle_x']) # original synthetic NeRF images are 800x800 + + self.focal *= self.img_wh[0]/800 + + self.near = 2.0 + self.far = 6.0 + self.bounds = np.array([self.near, self.far]) + + self.directions = get_ray_directions(h, w, self.focal) + + def __len__(self): + if self.split == "train": + return len(self.meta['frames']) + if self.split == "val": + return 1 # only validate 1 image (to support <=8 gpus) + return len(self.meta["frames"]) # for test + + def __getitem__(self, idx): + frame = self.meta["frames"][idx] + + c2w = torch.FloatTensor(frame['transform_matrix'])[:3, :4] + + img = Image.open(os.path.join( + self.root_dir, f"{frame['file_path']}.png")).convert("RGB") + + file_name = frame['file_path'].split('/')[-1] + # depth = Image.open(os.path.join(self.root_dir, f"train_depth/{file_name}-dpt_beit_large_512.png")) + + mask_img = Image.open(os.path.join( + self.root_dir, f'train_mask/{file_name}.png')).convert("RGB") + img = img.resize(self.img_wh, Image.LANCZOS) + mask_img = mask_img.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) + + mask_img = self.transform(mask_img) + + valid_mask = (img[-1] > 0).flatten() # note: img is RGB here, so this tests the blue channel; the original nerf_pl code applied this to the alpha channel + img = img.view(3, -1).permute(1, 0) + mask_img = mask_img.view(3, -1).permute(1, 0)[..., :1] + + rays_o, rays_d = get_rays(self.directions, c2w) + + rays_o = rays_o.to(self.device) + rays_d = rays_d.to(self.device) + + sample = { + "rays_o": rays_o, + "rays_d": rays_d, + "rgbs": img, + "mask": mask_img, + "H": self.img_wh[1], + "W": self.img_wh[0], + "c2w": c2w, + # TODO: also return the view directions + "valid_mask": valid_mask} + return sample diff --git a/nerf/clip.py b/nerf/clip.py new file mode 100755 index 0000000..0f815e1 --- /dev/null +++ b/nerf/clip.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + +import torchvision.transforms as T +import torchvision.transforms.functional
as TF + +import clip + +class CLIP(nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + self.model, self.preprocess = clip.load("ViT-B/32", device=device) + self.model.cuda().eval() + self.transformCLIP = T.Compose([ + T.Resize(size=224, interpolation= T.InterpolationMode.BICUBIC, max_size=None, antialias=None), + T.CenterCrop(size=(224,224)), + T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) + + ]) + + + + def get_text_embeds(self, text): + text_tokens = clip.tokenize(text).cuda() + with torch.no_grad(): + text_feature = self.model.encode_text(text_tokens).float() + return text_feature + + def encode_img(self, img): + img = self.transformCLIP(img) + image_feature = self.model.encode_image(img).float() + return image_feature + + + + + diff --git a/nerf/gui.py b/nerf/gui.py new file mode 100755 index 0000000..c69f98c --- /dev/null +++ b/nerf/gui.py @@ -0,0 +1,469 @@ +import math +import torch +import numpy as np +import dearpygui.dearpygui as dpg +from scipy.spatial.transform import Rotation as R + +from nerf.utils import * + + +class OrbitCamera: + def __init__(self, W, H, r=2, fovy=60): + self.W = W + self.H = H + self.radius = r # camera distance from center + self.fovy = fovy # in degree + self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point + self.rot = R.from_quat([1, 0, 0, 0]) # init camera matrix: [[1, 0, 0], [0, -1, 0], [0, 0, 1]] (to suit ngp convention) + self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! + + # pose + @property + def pose(self): + # first move camera to radius + res = np.eye(4, dtype=np.float32) + res[2, 3] -= self.radius + # rotate + rot = np.eye(4, dtype=np.float32) + rot[:3, :3] = self.rot.as_matrix() + res = rot @ res + # translate + res[:3, 3] -= self.center + return res + + # intrinsics + @property + def intrinsics(self): + focal = self.H / (2 * np.tan(np.deg2rad(self.fovy) / 2)) + return np.array([focal, focal, self.W // 2, self.H // 2]) + + + def orbit(self, dx, dy): + # rotate along camera up/side axis! + side = self.rot.as_matrix()[:3, 0] # why this is side --> ? # already normalized. + rotvec_x = self.up * np.deg2rad(-0.1 * dx) + rotvec_y = side * np.deg2rad(-0.1 * dy) + self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot + + def scale(self, delta): + self.radius *= 1.1 ** (-delta) + + def pan(self, dx, dy, dz=0): + # pan in camera coordinate system (careful on the sensitivity!) + self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([dx, dy, dz]) + + +class NeRFGUI: + def __init__(self, opt, trainer, debug=True): + self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
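+ # GUI-local render state: render_buffer accumulates up to max_spp samples per pixel, and need_update is set whenever the camera or a rendering option changes, which resets the accumulation (see test_step below).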
+ self.W = opt.W + self.H = opt.H + self.cam = OrbitCamera(self.W, self.H, r=opt.radius, fovy=opt.fovy) + self.debug = debug + self.bg_color = torch.ones(3, dtype=torch.float32) # default white bg + self.training = False + self.step = 0 # training step + + self.trainer = trainer + self.render_buffer = np.zeros((self.W, self.H, 3), dtype=np.float32) + self.need_update = True # camera moved, should reset accumulation + self.spp = 1 # sample per pixel + self.light_dir = np.array([opt.light_theta, opt.light_phi]) + self.ambient_ratio = 1.0 + self.mode = 'image' # choose from ['image', 'depth'] + self.shading = 'albedo' + + self.dynamic_resolution = True + self.downscale = 1 + self.train_steps = 16 + + dpg.create_context() + self.register_dpg() + self.test_step() + + + def __del__(self): + dpg.destroy_context() + + + def train_step(self): + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.train_gui(self.trainer.train_loader, step=self.train_steps) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + self.step += self.train_steps + self.need_update = True + + dpg.set_value("_log_train_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_train_log", f'step = {self.step: 5d} (+{self.train_steps: 2d}), loss = {outputs["loss"]:.4f}, lr = {outputs["lr"]:.5f}') + + # dynamic train steps + # max allowed train time per-frame is 500 ms + full_t = t / self.train_steps * 16 + train_steps = min(16, max(4, int(16 * 500 / full_t))) + if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8: + self.train_steps = train_steps + + + def prepare_buffer(self, outputs): + if self.mode == 'image': + return outputs['image'] + else: + return np.expand_dims(outputs['depth'], -1).repeat(3, -1) + + + def test_step(self): + + if self.need_update or self.spp < self.opt.max_spp: + + starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + starter.record() + + outputs = self.trainer.test_gui(self.cam.pose, self.cam.intrinsics, self.W, self.H, self.bg_color, self.spp, self.downscale, self.light_dir, self.ambient_ratio, self.shading) + + ender.record() + torch.cuda.synchronize() + t = starter.elapsed_time(ender) + + # update dynamic resolution + if self.dynamic_resolution: + # max allowed infer time per-frame is 200 ms + full_t = t / (self.downscale ** 2) + downscale = min(1, max(1/4, math.sqrt(200 / full_t))) + if downscale > self.downscale * 1.2 or downscale < self.downscale * 0.8: + self.downscale = downscale + + if self.need_update: + self.render_buffer = self.prepare_buffer(outputs) + self.spp = 1 + self.need_update = False + else: + self.render_buffer = (self.render_buffer * self.spp + self.prepare_buffer(outputs)) / (self.spp + 1) + self.spp += 1 + + dpg.set_value("_log_infer_time", f'{t:.4f}ms ({int(1000/t)} FPS)') + dpg.set_value("_log_resolution", f'{int(self.downscale * self.W)}x{int(self.downscale * self.H)}') + dpg.set_value("_log_spp", self.spp) + dpg.set_value("_texture", self.render_buffer) + + + def register_dpg(self): + + ### register texture + + with dpg.texture_registry(show=False): + dpg.add_raw_texture(self.W, self.H, self.render_buffer, format=dpg.mvFormat_Float_rgb, tag="_texture") + + ### register window + + # the rendered image, as the primary window + with dpg.window(tag="_primary_window", width=self.W, height=self.H): + + # add the texture + dpg.add_image("_texture") + + dpg.set_primary_window("_primary_window", 
True) + + # control window + with dpg.window(label="Control", tag="_control_window", width=400, height=300): + + # text prompt + if self.opt.text is not None: + dpg.add_text("text: " + self.opt.text, tag="_log_prompt_text") + + if self.opt.negative != '': + dpg.add_text("negative text: " + self.opt.negative, tag="_log_prompt_negative_text") + + # button theme + with dpg.theme() as theme_button: + with dpg.theme_component(dpg.mvButton): + dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) + dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) + dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) + + # time + if not self.opt.test: + with dpg.group(horizontal=True): + dpg.add_text("Train time: ") + dpg.add_text("no data", tag="_log_train_time") + + with dpg.group(horizontal=True): + dpg.add_text("Infer time: ") + dpg.add_text("no data", tag="_log_infer_time") + + with dpg.group(horizontal=True): + dpg.add_text("SPP: ") + dpg.add_text("1", tag="_log_spp") + + # train button + if not self.opt.test: + with dpg.collapsing_header(label="Train", default_open=True): + with dpg.group(horizontal=True): + dpg.add_text("Train: ") + + def callback_train(sender, app_data): + if self.training: + self.training = False + dpg.configure_item("_button_train", label="start") + else: + self.training = True + dpg.configure_item("_button_train", label="stop") + + dpg.add_button(label="start", tag="_button_train", callback=callback_train) + dpg.bind_item_theme("_button_train", theme_button) + + def callback_reset(sender, app_data): + @torch.no_grad() + def weight_reset(m: nn.Module): + reset_parameters = getattr(m, "reset_parameters", None) + if callable(reset_parameters): + m.reset_parameters() + self.trainer.model.apply(fn=weight_reset) + self.trainer.model.reset_extra_state() # for cuda_ray density_grid and step_counter + self.need_update = True + + dpg.add_button(label="reset", tag="_button_reset", callback=callback_reset) + dpg.bind_item_theme("_button_reset", theme_button) + + + with dpg.group(horizontal=True): + dpg.add_text("Checkpoint: ") + + def callback_save(sender, app_data): + self.trainer.save_checkpoint(full=True, best=False) + dpg.set_value("_log_ckpt", "saved " + os.path.basename(self.trainer.stats["checkpoints"][-1])) + self.trainer.epoch += 1 # use epoch to indicate different calls. + + dpg.add_button(label="save", tag="_button_save", callback=callback_save) + dpg.bind_item_theme("_button_save", theme_button) + + dpg.add_text("", tag="_log_ckpt") + + # save mesh + with dpg.group(horizontal=True): + dpg.add_text("Marching Cubes: ") + + def callback_mesh(sender, app_data): + self.trainer.save_mesh(resolution=256) + dpg.set_value("_log_mesh", "saved " + f'{self.trainer.name}_{self.trainer.epoch}.ply') + self.trainer.epoch += 1 # use epoch to indicate different calls. 
+ + dpg.add_button(label="mesh", tag="_button_mesh", callback=callback_mesh) + dpg.bind_item_theme("_button_mesh", theme_button) + + dpg.add_text("", tag="_log_mesh") + + with dpg.group(horizontal=True): + dpg.add_text("", tag="_log_train_log") + + + # rendering options + with dpg.collapsing_header(label="Options", default_open=True): + + # dynamic rendering resolution + with dpg.group(horizontal=True): + + def callback_set_dynamic_resolution(sender, app_data): + if self.dynamic_resolution: + self.dynamic_resolution = False + self.downscale = 1 + else: + self.dynamic_resolution = True + self.need_update = True + + dpg.add_checkbox(label="dynamic resolution", default_value=self.dynamic_resolution, callback=callback_set_dynamic_resolution) + dpg.add_text(f"{self.W}x{self.H}", tag="_log_resolution") + + # mode combo + def callback_change_mode(sender, app_data): + self.mode = app_data + self.need_update = True + + dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode) + + # bg_color picker + def callback_change_bg(sender, app_data): + self.bg_color = torch.tensor(app_data[:3], dtype=torch.float32) # only need RGB in [0, 1] + self.need_update = True + + dpg.add_color_edit((255, 255, 255), label="Background Color", width=200, tag="_color_editor", no_alpha=True, callback=callback_change_bg) + + # fov slider + def callback_set_fovy(sender, app_data): + self.cam.fovy = app_data + self.need_update = True + + dpg.add_slider_int(label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=self.cam.fovy, callback=callback_set_fovy) + + # dt_gamma slider + def callback_set_dt_gamma(sender, app_data): + self.opt.dt_gamma = app_data + self.need_update = True + + dpg.add_slider_float(label="dt_gamma", min_value=0, max_value=0.1, format="%.5f", default_value=self.opt.dt_gamma, callback=callback_set_dt_gamma) + + # max_steps slider + def callback_set_max_steps(sender, app_data): + self.opt.max_steps = app_data + self.need_update = True + + dpg.add_slider_int(label="max steps", min_value=1, max_value=1024, format="%d", default_value=self.opt.max_steps, callback=callback_set_max_steps) + + # aabb slider + def callback_set_aabb(sender, app_data, user_data): + # user_data is the dimension for aabb (xmin, ymin, zmin, xmax, ymax, zmax) + self.trainer.model.aabb_infer[user_data] = app_data + + # also change train aabb ? [better not...] 
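+ # (aabb_infer only restricts where points are sampled at inference; aabb_train and the density grid are untouched, so cropping here is non-destructive.)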
+ #self.trainer.model.aabb_train[user_data] = app_data + + self.need_update = True + + dpg.add_separator() + dpg.add_text("Axis-aligned bounding box:") + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="x", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=0) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=3) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="y", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=1) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=4) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="z", width=150, min_value=-self.opt.bound, max_value=0, format="%.2f", default_value=-self.opt.bound, callback=callback_set_aabb, user_data=2) + dpg.add_slider_float(label="", width=150, min_value=0, max_value=self.opt.bound, format="%.2f", default_value=self.opt.bound, callback=callback_set_aabb, user_data=5) + + # light dir + def callback_set_light_dir(sender, app_data, user_data): + self.light_dir[user_data] = app_data + self.need_update = True + + dpg.add_separator() + dpg.add_text("Plane Light Direction:") + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="theta", min_value=0, max_value=180, format="%.2f", default_value=self.opt.light_theta, callback=callback_set_light_dir, user_data=0) + + with dpg.group(horizontal=True): + dpg.add_slider_float(label="phi", min_value=0, max_value=360, format="%.2f", default_value=self.opt.light_phi, callback=callback_set_light_dir, user_data=1) + + # ambient ratio + def callback_set_abm_ratio(sender, app_data): + self.ambient_ratio = app_data + self.need_update = True + + dpg.add_slider_float(label="ambient", min_value=0, max_value=1.0, format="%.5f", default_value=self.ambient_ratio, callback=callback_set_abm_ratio) + + # shading mode + def callback_change_shading(sender, app_data): + self.shading = app_data + self.need_update = True + + dpg.add_combo(('albedo', 'lambertian', 'textureless', 'normal'), label='shading', default_value=self.shading, callback=callback_change_shading) + + + # debug info + if self.debug: + with dpg.collapsing_header(label="Debug"): + # pose + dpg.add_separator() + dpg.add_text("Camera Pose:") + dpg.add_text(str(self.cam.pose), tag="_log_pose") + + + ### register camera handler + + def callback_camera_drag_rotate(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.orbit(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_wheel_scale(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + delta = app_data + + self.cam.scale(delta) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + def callback_camera_drag_pan(sender, app_data): + + if not dpg.is_item_focused("_primary_window"): + return + + dx = app_data[1] + dy = app_data[2] + + self.cam.pan(dx, dy) + self.need_update = True + + if self.debug: + dpg.set_value("_log_pose", str(self.cam.pose)) + + + with dpg.handler_registry(): + 
dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate) + dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) + dpg.add_mouse_drag_handler(button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan) + + + dpg.create_viewport(title='torch-ngp', width=self.W, height=self.H, resizable=False) + + # TODO: seems dearpygui doesn't support resizing texture... + # def callback_resize(sender, app_data): + # self.W = app_data[0] + # self.H = app_data[1] + # # how to reload texture ??? + + # dpg.set_viewport_resize_callback(callback_resize) + + ### global theme + with dpg.theme() as theme_no_padding: + with dpg.theme_component(dpg.mvAll): + # set all padding to 0 to avoid scroll bar + dpg.add_theme_style(dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core) + dpg.add_theme_style(dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core) + + dpg.bind_item_theme("_primary_window", theme_no_padding) + + dpg.setup_dearpygui() + + #dpg.show_metrics() + + dpg.show_viewport() + + + def render(self): + + while dpg.is_dearpygui_running(): + # update texture every frame + if self.training: + self.train_step() + self.test_step() + dpg.render_dearpygui_frame() \ No newline at end of file diff --git a/nerf/llff.py b/nerf/llff.py new file mode 100755 index 0000000..784eb28 --- /dev/null +++ b/nerf/llff.py @@ -0,0 +1,573 @@ +import os +import json + +import torch +import torch.nn as nn +from torch.utils.data import Dataset +from torchvision import transforms as T + +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +import imageio +import numpy as np +# from utils import * +# This code is borrowed from https://github.com/kwea123/nerf_pl/ +# I modified the return batch for whole resolution image +import glob + +# LLFFDataset +from kornia import create_meshgrid +def get_ray_directions(H,W,focal): + grid = create_meshgrid(H,W,normalized_coordinates=False)[0] + i, j = grid.unbind(-1) + directions = torch.stack([(i-W/2)/focal, -(j-H/2)/focal, -torch.ones_like(i)], -1) + return directions # (H, W, 3) + + +def get_rays(directions, c2w): + rays_d = directions @ c2w[:, :3].T + rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) + rays_o = c2w[:, 3].expand(rays_d.shape) # H, W, 3 + rays_d = rays_d.view(-1, 3) # H*W, 3 + rays_o = rays_o.view(-1, 3) + + return rays_o, rays_d + +def get_ndc_rays(H, W, focal, near, rays_o, rays_d): + """ + Transform rays from world coordinate to NDC. + NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. + For detailed derivation, please see: + http://www.songho.ca/opengl/gl_projectionmatrix.html + https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf + + In practice, use NDC "if and only if" the scene is unbounded (has a large depth). 
+ See https://github.com/bmild/nerf/issues/18 + + Inputs: + H, W, focal: image height, width and focal length + near: (N_rays) or float, the depths of the near plane + rays_o: (N_rays, 3), the origin of the rays in world coordinate + rays_d: (N_rays, 3), the direction of the rays in world coordinate + + Outputs: + rays_o: (N_rays, 3), the origin of the rays in NDC + rays_d: (N_rays, 3), the direction of the rays in NDC + """ + # Shift ray origins to near plane + t = -(near + rays_o[...,2]) / rays_d[...,2] + rays_o = rays_o + t[...,None] * rays_d + + # Store some intermediate homogeneous results + ox_oz = rays_o[...,0] / rays_o[...,2] + oy_oz = rays_o[...,1] / rays_o[...,2] + + # Projection + o0 = -1./(W/(2.*focal)) * ox_oz + o1 = -1./(H/(2.*focal)) * oy_oz + o2 = 1. + 2. * near / rays_o[...,2] + + d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - ox_oz) + d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - oy_oz) + d2 = 1 - o2 + + rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) + rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) + + return rays_o, rays_d + +def normalize(v): + """Normalize a vector.""" + return v/np.linalg.norm(v) + + +def average_poses(poses): + """ + Calculate the average pose, which is then used to center all poses + using @center_poses. Its computation is as follows: + 1. Compute the center: the average of pose centers. + 2. Compute the z axis: the normalized average z axis. + 3. Compute axis y': the average y axis. + 4. Compute x' = y' cross product z, then normalize it as the x axis. + 5. Compute the y axis: z cross product x. + + Note that at step 3, we cannot directly use y' as y axis since it's + not necessarily orthogonal to z axis. We need to pass from x to y. + + Inputs: + poses: (N_images, 3, 4) + + Outputs: + pose_avg: (3, 4) the average pose + """ + # 1. Compute the center + center = poses[..., 3].mean(0) # (3) + + # 2. Compute the z axis + z = normalize(poses[..., 2].mean(0)) # (3) + + # 3. Compute axis y' (no need to normalize as it's not the final output) + y_ = poses[..., 1].mean(0) # (3) + + # 4. Compute the x axis + x = normalize(np.cross(y_, z)) # (3) + + # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) + y = np.cross(z, x) # (3) + + pose_avg = np.stack([x, y, z, center], 1) # (3, 4) + + return pose_avg + + +def center_poses(poses): + """ + Center the poses so that we can use NDC. + See https://github.com/bmild/nerf/issues/34 + + Inputs: + poses: (N_images, 3, 4) + + Outputs: + poses_centered: (N_images, 3, 4) the centered poses + pose_avg: (3, 4) the average pose + """ + + pose_avg = average_poses(poses) # (3, 4) + pose_avg_homo = np.eye(4) + pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation + # by simply adding 0, 0, 0, 1 as the last row + last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) + poses_homo = \ + np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate + + poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) + poses_centered = poses_centered[:, :3] # (N_images, 3, 4) + + return poses_centered, np.linalg.inv(pose_avg_homo) + + +def create_spiral_poses(radii, focus_depth, n_poses=120): + """ + Computes poses that follow a spiral path for rendering purpose. 
+ See https://github.com/Fyusion/LLFF/issues/19 + In particular, the path looks like: + https://tinyurl.com/ybgtfns3 + + Inputs: + radii: (3) radii of the spiral for each axis + focus_depth: float, the depth that the spiral poses look at + n_poses: int, number of poses to create along the path + + Outputs: + poses_spiral: (n_poses, 3, 4) the poses in the spiral path + """ + + poses_spiral = [] + for t in np.linspace(0, 4*np.pi, n_poses+1)[:-1]: # rotate 4pi (2 rounds) + # the parametric function of the spiral (see the interactive web) + center = np.array([np.cos(t), -np.sin(t), -np.sin(0.5*t)]) * radii + + # the viewing z axis is the vector pointing from the @focus_depth plane + # to @center + z = normalize(center - np.array([0, 0, -focus_depth])) + + # compute other axes as in @average_poses + y_ = np.array([0, 1, 0]) # (3) + x = normalize(np.cross(y_, z)) # (3) + y = np.cross(z, x) # (3) + + poses_spiral += [np.stack([x, y, z, center], 1)] # (3, 4) + + return np.stack(poses_spiral, 0) # (n_poses, 3, 4) + + +def create_spheric_poses(radius, n_poses=120): + """ + Create circular poses around z axis. + Inputs: + radius: the (negative) height and the radius of the circle. + + Outputs: + spheric_poses: (n_poses, 3, 4) the poses in the circular path + """ + def spheric_pose(theta, phi, radius): + trans_t = lambda t : np.array([ + [1,0,0,0], + [0,1,0,-0.9*t], + [0,0,1,t], + [0,0,0,1], + ]) + + rot_phi = lambda phi : np.array([ + [1,0,0,0], + [0,np.cos(phi),-np.sin(phi),0], + [0,np.sin(phi), np.cos(phi),0], + [0,0,0,1], + ]) + + rot_theta = lambda th : np.array([ + [np.cos(th),0,-np.sin(th),0], + [0,1,0,0], + [np.sin(th),0, np.cos(th),0], + [0,0,0,1], + ]) + + c2w = rot_theta(theta) @ rot_phi(phi) @ trans_t(radius) + c2w = np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]) @ c2w + return c2w[:3] + + spheric_poses = [] + for th in np.linspace(0, 2*np.pi, n_poses+1)[:-1]: + spheric_poses += [spheric_pose(th, -np.pi/5, radius)] # 36 degree view downwards + return np.stack(spheric_poses, 0) + + +def read_depth_image(path, img_wh): + """Read image and output RGB image (0-1). 
+ + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + # print(path) + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + img = cv2.resize(img, img_wh) + img = rearrange(img, 'h w c -> (h w) c') + return img # H W C + +# read color images or depth images +def read_image(img_path, img_wh, blend_a=True): + img = imageio.imread(img_path).astype(np.float32)/255.0 + # img[..., :3] = srgb_to_linear(img[..., :3]) + + if img.shape[2] == 4: # blend A to RGB + if blend_a: + img = img[..., :3]*img[..., -1:]+(1-img[..., -1:]) + else: + img = img[..., :3]*img[..., -1:] + + img = cv2.resize(img, img_wh) + img = rearrange(img, 'h w c -> (h w) c') + + return img + +class LLFFDataset(Dataset): + def __init__(self, device, root_dir, split='train', img_wh=(504, 378), spheric_poses=False, val_num=1): + """ + spheric_poses: whether the images are taken in a spheric inward-facing manner + default: False (forward-facing) + val_num: number of val images (used for multigpu training, validate same image for all gpus) + """ + self.root_dir = root_dir + self.split = split + self.img_wh = img_wh + self.spheric_poses = spheric_poses + self.val_num = max(1, val_num) # at least 1 + self.define_transforms() + self.device = device + + self.read_meta() + self.white_back = False + + def read_meta(self): + poses_bounds = np.load(os.path.join(self.root_dir, + 'poses_bounds.npy')) # (N_images, 17) + self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*'))) + self.mask_paths = sorted(glob.glob(os.path.join(self.root_dir, 'train_mask/*'))) + self.depth_paths = sorted(glob.glob(os.path.join(self.root_dir, "train_depth/*"))) + # load full resolution image then resize + if self.split in ['train', 'val']: + assert len(poses_bounds) == len(self.image_paths), \ + 'Mismatch between number of images and number of poses! Please rerun COLMAP!' + + poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) + self.bounds = poses_bounds[:, -2:] # (N_images, 2) + + # Step 1: rescale focal length according to training resolution + H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images + # assert H*self.img_wh[0] == W*self.img_wh[1], \ + # f'You must set @img_wh to have the same aspect ratio as ({W}, {H}) !' 
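+ # the focal length stored in poses_bounds.npy refers to the original capture width W, so rescale it to the training width below (e.g. a hypothetical 4032-px capture resized to img_wh[0]=504 would scale focal by 1/8).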
+ + self.focal *= self.img_wh[0]/W + + # Step 2: correct poses + # Original poses has rotation in form "down right back", change to "right up back" + # See https://github.com/bmild/nerf/issues/34 + poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) + # (N_images, 3, 4) exclude H, W, focal + self.poses, self.pose_avg = center_poses(poses) + distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) + val_idx = np.argmin(distances_from_center) # choose val image as the closest to + # center image + + # Step 3: correct scale so that the nearest depth is at a little more than 1.0 + # See https://github.com/bmild/nerf/issues/34 + near_original = self.bounds.min() + scale_factor = near_original*0.75 # 0.75 is the default parameter + # the nearest depth is at 1/0.75=1.33 + self.bounds /= scale_factor + self.poses[..., 3] /= scale_factor + + # ray directions for all pixels, same for all images (same H, W, focal) + self.directions = \ + get_ray_directions(self.img_wh[1], self.img_wh[0], self.focal) # (H, W, 3) + + if self.split.endswith('train'): # test on training set + self.poses_test = self.poses + elif not self.spheric_poses: + focus_depth = 3.5 # hardcoded, this is numerically close to the formula + # given in the original repo. Mathematically if near=1 + # and far=infinity, then this number will converge to 4 + radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0) + self.poses_test = create_spiral_poses(radii, focus_depth) + else: + radius = 1.1 * self.bounds.min() + self.poses_test = create_spheric_poses(radius) + + # self.poses_test = self.poses_test.to(device) + + + + def define_transforms(self): + self.transform = T.ToTensor() + + def __len__(self): + if self.split == 'train': + return len(self.poses) + if self.split == 'val': + return self.val_num + return len(self.poses_test) + + def __getitem__(self, idx): + if self.split == 'train' or self.split=='val': # use data in the buffers + c2w = torch.FloatTensor(self.poses[idx]) + else: + c2w = torch.FloatTensor(self.poses_test[idx]) + + rays_o, rays_d = get_rays(self.directions, c2w) + if not self.spheric_poses: + near, far = 0, 1 + rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0], + self.focal, 1.0, rays_o, rays_d) + else: + near = self.bounds.min() + far = min(8 * near, self.bounds.max()) + + + rays_o = rays_o.to(self.device) + rays_d = rays_d.to(self.device) + c2w = c2w.to(self.device) + sample = {'rays_o':rays_o, + 'rays_d':rays_d, + 'c2w': c2w} + + if self.split == 'train' or self.split=='val': + img = Image.open(self.image_paths[idx])#.convert('RGB') + mask = Image.open(self.mask_paths[idx])#.convert('RGB') + img = img.resize(self.img_wh, Image.LANCZOS) + mask = mask.resize(self.img_wh, Image.LANCZOS) + img = self.transform(img) # (3, h, w) + + + mask = self.transform(mask) + + img = img.view(3, -1).permute(1, 0) # (h*w, 3)# TODO: shape + mask = mask.view(3, -1).permute(1, 0) + + + sample['rgbs'] = img + sample['mask'] = mask[:, :1] + sample['W'] = self.img_wh[0] + sample['H'] = self.img_wh[1] + return sample + + + + + + + + + +# class LLFFDataset(Dataset): +# def __init__(self, root_dir, split='train', img_wh=(504, 378), spheric_poses=False, val_num=1): +# """ +# spheric_poses: whether the images are taken in a spheric inward-facing manner +# default: False (forward-facing) +# val_num: number of val images (used for multigpu training, validate same image for all gpus) +# """ +# self.root_dir = root_dir +# self.split = split +# self.img_wh = img_wh +# self.spheric_poses = 
spheric_poses +# self.val_num = max(1, val_num) # at least 1 +# self.define_transforms() + +# self.read_meta() +# self.white_back = False + +# def read_meta(self): +# poses_bounds = np.load(os.path.join(self.root_dir, +# 'poses_bounds.npy')) # (N_images, 17) +# self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images/*'))) +# # load full resolution image then resize +# if self.split in ['train', 'val']: +# assert len(poses_bounds) == len(self.image_paths), \ +# 'Mismatch between number of images and number of poses! Please rerun COLMAP!' + +# poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) +# self.bounds = poses_bounds[:, -2:] # (N_images, 2) + +# # Step 1: rescale focal length according to training resolution +# H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images +# assert H*self.img_wh[0] == W*self.img_wh[1], \ +# f'You must set @img_wh to have the same aspect ratio as ({W}, {H}) !' + +# self.focal *= self.img_wh[0]/W + +# # Step 2: correct poses +# # Original poses has rotation in form "down right back", change to "right up back" +# # See https://github.com/bmild/nerf/issues/34 +# poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) +# # (N_images, 3, 4) exclude H, W, focal +# self.poses, self.pose_avg = center_poses(poses) +# distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) +# val_idx = np.argmin(distances_from_center) # choose val image as the closest to +# # center image + +# # Step 3: correct scale so that the nearest depth is at a little more than 1.0 +# # See https://github.com/bmild/nerf/issues/34 +# near_original = self.bounds.min() +# scale_factor = near_original*0.75 # 0.75 is the default parameter +# # the nearest depth is at 1/0.75=1.33 +# self.bounds /= scale_factor +# self.poses[..., 3] /= scale_factor + +# # ray directions for all pixels, same for all images (same H, W, focal) +# self.directions = \ +# get_ray_directions(self.img_wh[1], self.img_wh[0], self.focal) # (H, W, 3) + +# if self.split == 'train': # create buffer of all rays and rgb data +# # use first N_images-1 to train, the LAST is val +# self.all_rays_o = [] +# self.all_rays_d = [] +# self.all_rgbs = [] +# for i, image_path in enumerate(self.image_paths): +# if i == val_idx: # exclude the val image +# continue +# c2w = torch.FloatTensor(self.poses[i]) + +# img = Image.open(image_path).convert('RGB') +# assert img.size[1]*self.img_wh[0] == img.size[0]*self.img_wh[1], \ +# f'''{image_path} has different aspect ratio than img_wh, +# please check your data!''' +# img = img.resize(self.img_wh, Image.LANCZOS) +# img = self.transform(img) # (3, h, w) +# img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB +# self.all_rgbs += [img] + +# rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) +# if not self.spheric_poses: +# near, far = 0, 1 +# rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0], +# self.focal, 1.0, rays_o, rays_d) +# # near plane is always at 1.0 +# # near and far in NDC are always 0 and 1 +# # See https://github.com/bmild/nerf/issues/34 +# else: +# near = self.bounds.min() +# far = min(8 * near, self.bounds.max()) # focus on central object only + +# # self.all_rays += [torch.cat([rays_o, rays_d, +# # near*torch.ones_like(rays_o[:, :1]), +# # far*torch.ones_like(rays_o[:, :1])], +# # 1)] # (h*w, 8) +# self.all_rays_o +=[rays_o] +# self.all_rays_d +=[rays_d] + +# self.all_rays_o = torch.cat(self.all_rays_o, 0) # ((N_images-1)*h*w, 8) +# self.all_rays_d = torch.cat(self.all_rays_d, 0) 
# ((N_images-1)*h*w, 8) +# self.all_rgbs = torch.cat(self.all_rgbs, 0) # ((N_images-1)*h*w, 3) + +# elif self.split == 'val': +# print('val image is', self.image_paths[val_idx]) +# self.c2w_val = self.poses[val_idx] +# self.image_path_val = self.image_paths[val_idx] + +# else: # for testing, create a parametric rendering path +# if self.split.endswith('train'): # test on training set +# self.poses_test = self.poses +# elif not self.spheric_poses: +# focus_depth = 3.5 # hardcoded, this is numerically close to the formula +# # given in the original repo. Mathematically if near=1 +# # and far=infinity, then this number will converge to 4 +# radii = np.percentile(np.abs(self.poses[..., 3]), 90, axis=0) +# self.poses_test = create_spiral_poses(radii, focus_depth) +# else: +# radius = 1.1 * self.bounds.min() +# self.poses_test = create_spheric_poses(radius) + +# def define_transforms(self): +# self.transform = T.ToTensor() + +# def __len__(self): +# if self.split == 'train': +# return len(self.all_rgbs) +# if self.split == 'val': +# return self.val_num +# return len(self.poses_test) + +# def __getitem__(self, idx): +# if self.split == 'train': # use data in the buffers +# sample = {'rays_o': self.all_rays_o[idx].unsqueeze(0), +# 'rays_d': self.all_rays_d[idx].unsqueeze(0), +# 'rgbs': self.all_rgbs[idx]} + +# else: +# if self.split == 'val': +# c2w = torch.FloatTensor(self.c2w_val) +# else: +# c2w = torch.FloatTensor(self.poses_test[idx]) + +# rays_o, rays_d = get_rays(self.directions, c2w) +# if not self.spheric_poses: +# near, far = 0, 1 +# rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0], +# self.focal, 1.0, rays_o, rays_d) +# else: +# near = self.bounds.min() +# far = min(8 * near, self.bounds.max()) + +# # rays = torch.cat([rays_o, rays_d, +# # near*torch.ones_like(rays_o[:, :1]), +# # far*torch.ones_like(rays_o[:, :1])], +# # 1) # (h*w, 8) + +# sample = {'rays_o':rays_o.unsqueeze(0), +# 'rays_d':rays_d.unsqueeze(0), +# 'c2w': c2w} + +# if self.split == 'val': +# img = Image.open(self.image_path_val).convert('RGB') +# img = img.resize(self.img_wh, Image.LANCZOS) +# img = self.transform(img) # (3, h, w) +# img = img.view(3, -1).permute(1, 0) # (h*w, 3) +# sample['rgbs'] = img.unsqueeze(0) +# sample['W'] = self.img_wh[0] +# sample['H'] = self.img_wh[1] +# return sample + + + + + diff --git a/nerf/network_grid.py b/nerf/network_grid.py new file mode 100755 index 0000000..752c468 --- /dev/null +++ b/nerf/network_grid.py @@ -0,0 +1,193 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from activation import trunc_exp +from .renderer import NeRFRenderer + +import numpy as np +from encoding import get_encoder + +from .utils import safe_normalize + +class MLP(nn.Module): + def __init__(self, dim_in, dim_out, dim_hidden, num_layers, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for l in range(num_layers): + net.append(nn.Linear(self.dim_in if l == 0 else self.dim_hidden, self.dim_out if l == num_layers - 1 else self.dim_hidden, bias=bias)) + + self.net = nn.ModuleList(net) + + def forward(self, x): + for l in range(self.num_layers): + x = self.net[l](x) + if l != self.num_layers - 1: + x = F.relu(x, inplace=True) + return x + + +class NeRFNetwork(NeRFRenderer): + def __init__(self, + opt, + num_layers=2, + hidden_dim=32, + num_layers_bg=2, + hidden_dim_bg=64, + ): + + super().__init__(opt) + + self.num_layers = num_layers + self.hidden_dim = hidden_dim 
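+ # two separate encoders: a multiresolution hash grid for the 3D position and a spherical-harmonics encoding for the view direction; common_forward below concatenates their features before the color MLP.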
+ + self.encoder, self.in_dim = get_encoder('hashgrid', input_dim=3, log2_hashmap_size=16, desired_resolution=1024 * self.bound) + self.encoder_dir, self.in_dim_dir = get_encoder('sphere_harmonics', input_dim=3) + self.sigma_net = MLP(self.in_dim, 1, hidden_dim, num_layers, bias=True) + self.color_net = MLP(self.in_dim+self.in_dim_dir, 3, hidden_dim, num_layers, bias=True) + + # background network + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + + # use a very simple network to avoid it learning the prompt... + # self.encoder_bg, self.in_dim_bg = get_encoder('tiledgrid', input_dim=2, num_levels=4, desired_resolution=2048) + self.encoder_bg, self.in_dim_bg = get_encoder('frequency', input_dim=3, multires=4) + + self.bg_net = MLP(self.in_dim_bg, 3, hidden_dim_bg, num_layers_bg, bias=True) + + else: + self.bg_net = None + + # add a density blob to the scene center + def gaussian(self, x): + # x: [B, N, 3] + + d = (x ** 2).sum(-1) + g = self.opt.blob_density * torch.exp(- d / (self.opt.blob_radius ** 2)) + + return g + + def common_forward(self, x, d): + # x: [N, 3], in [-bound, bound]; d: [N, 3], view direction + + # sigma + # print("x:", x.shape) + h = self.encoder(x, bound=self.bound) + # print(self.bound) + # print("h",h.shape) + sigma = self.sigma_net(h) + + sigma = trunc_exp(sigma) + h_d = self.encoder_dir(d) + albedo = torch.sigmoid(self.color_net(torch.cat([h, h_d], -1))) + # albedo = torch.sigmoid(h[..., 1:]) + + return sigma, albedo + + # ref: https://github.com/zhaofuq/Instant-NSR/blob/main/nerf/network_sdf.py#L192 + def finite_difference_normal(self, x, d, epsilon=1e-2): + # x: [N, 3]; d is only forwarded to common_forward (the predicted albedo is discarded here) + dx_pos, _ = self.common_forward((x + torch.tensor([[epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound), d) + dx_neg, _ = self.common_forward((x + torch.tensor([[-epsilon, 0.00, 0.00]], device=x.device)).clamp(-self.bound, self.bound), d) + dy_pos, _ = self.common_forward((x + torch.tensor([[0.00, epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound), d) + dy_neg, _ = self.common_forward((x + torch.tensor([[0.00, -epsilon, 0.00]], device=x.device)).clamp(-self.bound, self.bound), d) + dz_pos, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, epsilon]], device=x.device)).clamp(-self.bound, self.bound), d) + dz_neg, _ = self.common_forward((x + torch.tensor([[0.00, 0.00, -epsilon]], device=x.device)).clamp(-self.bound, self.bound), d) + + normal = torch.stack([ + 0.5 * (dx_pos - dx_neg) / epsilon, + 0.5 * (dy_pos - dy_neg) / epsilon, + 0.5 * (dz_pos - dz_neg) / epsilon + ], dim=-1) + + return -normal + + + def normal(self, x, d): + + normal = self.finite_difference_normal(x, d) + normal = safe_normalize(normal) + normal = torch.nan_to_num(normal) + + return normal + + + def forward(self, x, d, l=None, ratio=1, shading='albedo'): + # x: [N, 3], in [-bound, bound] + # d: [N, 3], view direction, normalized in [-1, 1] + # l: [3], plane light direction, normalized in [-1, 1] + # ratio: scalar, ambient ratio, 1 == no shading (albedo only), 0 == only shading (textureless) + + if shading == 'albedo': + # no need to query normal + sigma, color = self.common_forward(x, d) + normal = None + + else: + # query normal + + sigma, albedo = self.common_forward(x, d) + normal = self.normal(x, d) + + # lambertian shading + lambertian = ratio + (1 - ratio) * (normal @ l).clamp(min=0) # [N,] + + if shading == 'textureless': + color = lambertian.unsqueeze(-1).repeat(1, 3) + elif shading == 'normal': + color = (normal + 1) / 2 + else: # 'lambertian' + color = albedo * lambertian.unsqueeze(-1) + + 
return sigma, color, normal + + + def density(self, x): + # x: [N, 3], in [-bound, bound] + + h = self.encoder(x, bound=self.bound) + # print(self.bound) + # print("h",h.shape) + sigma = self.sigma_net(h) + + sigma = trunc_exp(sigma) + + return { + 'sigma': sigma, + # 'albedo': albedo, + } + + + def background(self, d): + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + # optimizer utils + def get_params(self, lr): + + params = [ + {'params': self.encoder.parameters(), 'lr': lr * 10}, + {'params': self.encoder_dir.parameters(), 'lr': lr}, + {'params': self.sigma_net.parameters(), 'lr': lr}, + {'params': self.color_net.parameters(), 'lr': lr}, + ] + + if self.bg_radius > 0: + # params.append({'params': self.encoder_bg.parameters(), 'lr': lr * 10}) + params.append({'params': self.bg_net.parameters(), 'lr': lr}) + + return params \ No newline at end of file diff --git a/nerf/renderer.py b/nerf/renderer.py new file mode 100755 index 0000000..6e30126 --- /dev/null +++ b/nerf/renderer.py @@ -0,0 +1,688 @@ +import os +import math +import cv2 +import trimesh +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import mcubes +import raymarching +from .utils import custom_meshgrid, safe_normalize + +def sample_pdf(bins, weights, n_samples, det=False): + # This implementation is from NeRF + # bins: [B, T], old_z_vals + # weights: [B, T - 1], bin weights. + # return: [B, n_samples], new_z_vals + + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. 
- 0.5 / n_samples, steps=n_samples).to(weights.device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) + + # Invert CDF + u = u.contiguous() + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + return samples + +@torch.cuda.amp.autocast(enabled=False) +def near_far_from_bound(rays_o, rays_d, bound, type='cube', min_near=0.05): + # rays: [B, N, 3], [B, N, 3] + # bound: int, radius for ball or half-edge-length for cube + # return near [B, N, 1], far [B, N, 1] + + radius = rays_o.norm(dim=-1, keepdim=True) + + if type == 'sphere': + near = radius - bound # [B, N, 1] + far = radius + bound + + elif type == 'cube': + tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3] + tmax = (bound - rays_o) / (rays_d + 1e-15) + near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0] + far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0] + # if far < near, means no intersection, set both near and far to inf (1e9 here) + mask = far < near + near[mask] = 1e9 + far[mask] = 1e9 + # restrict near to a minimal value + near = torch.clamp(near, min=min_near) + + return near, far + + +def plot_pointcloud(pc, color=None): + # pc: [N, 3] + # color: [N, 3/4] + print('[visualize points]', pc.shape, pc.dtype, pc.min(0), pc.max(0)) + pc = trimesh.PointCloud(pc, color) + # axis + axes = trimesh.creation.axis(axis_length=4) + # sphere + sphere = trimesh.creation.icosphere(radius=1) + trimesh.Scene([pc, axes, sphere]).show() + + +class NeRFRenderer(nn.Module): + def __init__(self, opt): + super().__init__() + + self.opt = opt + self.bound = opt.bound + self.cascade = 1 + math.ceil(math.log2(opt.bound)) + self.grid_size = 128 + self.cuda_ray = opt.cuda_ray + self.min_near = opt.min_near + self.density_thresh = opt.density_thresh + self.bg_radius = opt.bg_radius + + # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) + # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. + aabb_train = torch.FloatTensor([-opt.bound, -opt.bound, -opt.bound, opt.bound, opt.bound, opt.bound]) + aabb_infer = aabb_train.clone() + self.register_buffer('aabb_train', aabb_train) + self.register_buffer('aabb_infer', aabb_infer) + + # extra state for cuda raymarching + if self.cuda_ray: + # density grid + density_grid = torch.zeros([self.cascade, self.grid_size ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros(self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + # step counter + step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... 
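+ # each of the 16 rows records the sample counts of one recent CUDA raymarching step; their average (mean_count below) is a running estimate of how many points a step produces.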
+ self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 + + + def forward(self, x, d): + raise NotImplementedError() + + def density(self, x): + raise NotImplementedError() + + def color(self, x, d, mask=None, **kwargs): + raise NotImplementedError() + + def reset_extra_state(self): + if not self.cuda_ray: + return + # density grid + self.density_grid.zero_() + self.mean_density = 0 + self.iter_density = 0 + # step counter + self.step_counter.zero_() + self.mean_count = 0 + self.local_step = 0 + + @torch.no_grad() + def export_mesh(self, path, resolution=None, S=128): + + if resolution is None: + resolution = self.grid_size + + if self.cuda_ray: + density_thresh = min(self.mean_density, self.density_thresh) + else: + density_thresh = self.density_thresh + + sigmas = np.zeros([resolution, resolution, resolution], dtype=np.float32) + + # query + X = torch.linspace(-1, 1, resolution).split(S) + Y = torch.linspace(-1, 1, resolution).split(S) + Z = torch.linspace(-1, 1, resolution).split(S) + + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = custom_meshgrid(xs, ys, zs) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [S, 3] + val = self.density(pts.to(self.aabb_train.device)) + sigmas[xi * S: xi * S + len(xs), yi * S: yi * S + len(ys), zi * S: zi * S + len(zs)] = val['sigma'].reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [S, 1] --> [x, y, z] + + vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh) + + vertices = vertices / (resolution - 1.0) * 2 - 1 + vertices = vertices.astype(np.float32) + triangles = triangles.astype(np.int32) + + v = torch.from_numpy(vertices).to(self.aabb_train.device) + f = torch.from_numpy(triangles).int().to(self.aabb_train.device) + + # mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... + # mesh.export(os.path.join(path, f'mesh.ply')) + + # texture? + def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''): + # v, f: torch Tensor + device = v.device + v_np = v.cpu().numpy() # [N, 3] + f_np = f.cpu().numpy() # [M, 3] + + print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}') + + # unwrap uvs + import xatlas + import nvdiffrast.torch as dr + from sklearn.neighbors import NearestNeighbors + from scipy.ndimage import binary_dilation, binary_erosion + + glctx = dr.RasterizeCudaContext() + + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + chart_options.max_iterations = 0 # disable merge_chart for faster unwrap... 
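+ # after generate(), indexing the atlas yields, per mesh: vmapping (an index into the original vertices for every unwrapped vertex), the per-face uv indices, and the uv coordinates in [0, 1].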
+ atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2] + + vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device) + ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device) + + # render uv maps + uv = vt * 2.0 - 1.0 # uvs to range [-1, 1] + uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] + + if ssaa > 1: + h = int(h0 * ssaa) + w = int(w0 * ssaa) + else: + h, w = h0, w0 + + rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4] + xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] + mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] + + # masked query + xyzs = xyzs.view(-1, 3) + mask = (mask > 0).view(-1) + + sigmas = torch.zeros(h * w, device=device, dtype=torch.float32) + feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32) + + if mask.any(): + xyzs = xyzs[mask] # [M, 3] + + # batched inference to avoid OOM + all_sigmas = [] + all_feats = [] + head = 0 + while head < xyzs.shape[0]: + tail = min(head + 640000, xyzs.shape[0]) + results_ = self.density(xyzs[head:tail]) + all_sigmas.append(results_['sigma'].float()) + all_feats.append(results_['albedo'].float()) + head += 640000 + + sigmas[mask] = torch.cat(all_sigmas, dim=0) + feats[mask] = torch.cat(all_feats, dim=0) + + sigmas = sigmas.view(h, w, 1) + feats = feats.view(h, w, -1) + mask = mask.view(h, w) + + ### alpha mask + # deltas = 2 * np.sqrt(3) / 1024 + # alphas = 1 - torch.exp(-sigmas * deltas) + # alphas_mask = alphas > 0.5 + # feats = feats * alphas_mask + + # quantize [0.0, 1.0] to [0, 255] + feats = feats.cpu().numpy() + feats = (feats * 255).astype(np.uint8) + + # alphas = alphas.cpu().numpy() + # alphas = (alphas * 255).astype(np.uint8) + + ### NN search as an antialiasing ... 
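+ # Texels just outside the UV islands would otherwise bleed black into the
+ # texture under bilinear filtering, so the block below dilates the coverage
+ # mask and fills every uncovered texel from its nearest covered texel via a
+ # kd-tree 1-NN lookup.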
+ mask = mask.cpu().numpy() + + inpaint_region = binary_dilation(mask, iterations=3) + inpaint_region[mask] = 0 + + search_region = mask.copy() + not_search_region = binary_erosion(search_region, iterations=2) + search_region[not_search_region] = 0 + + search_coords = np.stack(np.nonzero(search_region), axis=-1) + inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) + + knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords) + _, indices = knn.kneighbors(inpaint_coords) + + feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)] + + # do ssaa after the NN search, in numpy + feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR) + + if ssaa > 1: + # alphas = cv2.resize(alphas, (w0, h0), interpolation=cv2.INTER_NEAREST) + feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR) + + # cv2.imwrite(os.path.join(path, f'alpha.png'), alphas) + cv2.imwrite(os.path.join(path, f'{name}albedo.png'), feats) + + # save obj (v, vt, f /) + obj_file = os.path.join(path, f'{name}mesh.obj') + mtl_file = os.path.join(path, f'{name}mesh.mtl') + + print(f'[INFO] writing obj mesh to {obj_file}') + with open(obj_file, "w") as fp: + fp.write(f'mtllib {name}mesh.mtl \n') + + print(f'[INFO] writing vertices {v_np.shape}') + for v in v_np: + fp.write(f'v {v[0]} {v[1]} {v[2]} \n') + + print(f'[INFO] writing vertices texture coords {vt_np.shape}') + for v in vt_np: + fp.write(f'vt {v[0]} {1 - v[1]} \n') + + print(f'[INFO] writing faces {f_np.shape}') + fp.write(f'usemtl mat0 \n') + for i in range(len(f_np)): + fp.write(f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n") + + with open(mtl_file, "w") as fp: + fp.write(f'newmtl mat0 \n') + fp.write(f'Ka 1.000000 1.000000 1.000000 \n') + fp.write(f'Kd 1.000000 1.000000 1.000000 \n') + fp.write(f'Ks 0.000000 0.000000 0.000000 \n') + fp.write(f'Tr 1.000000 \n') + fp.write(f'illum 1 \n') + fp.write(f'Ns 0.000000 \n') + fp.write(f'map_Kd {name}albedo.png \n') + + _export(v, f) + + def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # bg_color: [BN, 3] in range [0, 1] + # return: image: [B, N, 3], depth: [B, N] + + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + results = {} + + # choose aabb + aabb = self.aabb_train if self.training else self.aabb_infer + + # sample steps + # nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, aabb, self.min_near) + # nears.unsqueeze_(-1) + # fars.unsqueeze_(-1) + nears, fars = near_far_from_bound(rays_o, rays_d, self.bound, type='sphere', min_near=self.min_near) + + # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float)) + light_d = safe_normalize(light_d) + + #print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') + + z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T] + z_vals = z_vals.expand((N, num_steps)) # [N, T] + z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] + + # perturb z_vals + sample_dist = (fars - nears) / num_steps + if perturb: + 
z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist + #z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. + + # generate xyzs + xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] + xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. + + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + # query SDF and RGB + density_outputs = self.density(xyzs.reshape(-1, 3)) + + #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] + for k, v in density_outputs.items(): + density_outputs[k] = v.view(N, num_steps, -1) + + # upsample z_vals (nerf-like) + if upsample_steps > 0: + with torch.no_grad(): + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + + alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] + + # sample new z_vals + z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] + new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t] + + new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] + new_xyzs = torch.min(torch.max(new_xyzs, aabb[:3]), aabb[3:]) # a manual clip. + + # only forward new points to save computation + new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) + #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] + for k, v in new_density_outputs.items(): + new_density_outputs[k] = v.view(N, upsample_steps, -1) + + # re-order + z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] + z_vals, z_index = torch.sort(z_vals, dim=1) + + xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] + xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) + + for k in density_outputs: + tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) + density_outputs[k] = torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + alphas = 1 - torch.exp(-deltas * density_outputs['sigma'].squeeze(-1)) # [N, T+t] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] + + dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) + for k, v in density_outputs.items(): + density_outputs[k] = v.view(-1, v.shape[-1]) + + sigmas, rgbs, normals = self(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), light_d, ratio=ambient_ratio, shading=shading) + rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] + + if normals is not None: + # orientation loss + normals = normals.view(N, -1, 3) + loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 + results['loss_orient'] = loss_orient.sum(-1).mean() + + # surface normal smoothness + normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2).view(N, -1, 3) + loss_smooth = (normals - normals_perturb).abs() + results['loss_smooth'] = loss_smooth.mean() + + # calculate weight_sum (mask) + weights_sum = weights.sum(dim=-1) # [N] + + # 
calculate depth + ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) + depth = torch.sum(weights * ori_z_vals, dim=-1) + + # calculate color + image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] + + # mix background color + if self.bg_radius > 0: + # use the bg model to calculate bg_color + # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1] + bg_color = self.background(rays_d.reshape(-1, 3)) # [N, 3] + elif bg_color is None: + bg_color = 1 + + # image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + + mask = (nears < fars).reshape(*prefix) + + results['image'] = image + results['depth'] = depth + results['weights_sum'] = weights_sum + results['mask'] = mask + + return results + + + def run_cuda(self, rays_o, rays_d, dt_gamma=0, light_d=None, ambient_ratio=1.0, shading='albedo', bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # return: image: [B, N, 3], depth: [B, N] + # print("run cuda.........................") + # print(rays_d.shape, rays_o.shape) + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + device = rays_o.device + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer) + + # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = (rays_o[0] + torch.randn(3, device=device, dtype=torch.float)) + light_d = safe_normalize(light_d) + + results = {} + + if self.training: + # setup counter + counter = self.step_counter[self.local_step % 16] + counter.zero_() # set to 0 + self.local_step += 1 + + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps) + + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) + + sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) + + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + + weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh) + + # normals related regularizations + if normals is not None: + # orientation loss (not very exact in cuda ray mode) + weights = 1 - torch.exp(-sigmas) + loss_orient = weights.detach() * (normals * dirs).sum(-1).clamp(min=0) ** 2 + results['loss_orient'] = loss_orient.mean() + + # surface normal smoothness + normals_perturb = self.normal(xyzs + torch.randn_like(xyzs) * 1e-2) + loss_smooth = (normals - normals_perturb).abs() + results['loss_smooth'] = loss_smooth.mean() + + else: + # print("data preparation") + # allocate outputs + dtype = torch.float32 + + weights_sum = torch.zeros(N, dtype=dtype, device=device) + depth = torch.zeros(N, dtype=dtype, device=device) + image = torch.zeros(N, 3, dtype=dtype, device=device) + + n_alive = N + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N] + rays_t = nears.clone() # [N] + + step = 0 + + while step < max_steps: # hard coded max step + # print("step") + # count alive rays + n_alive = rays_alive.shape[0] + + # 
exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(N // n_alive, 8), 1) + + xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps) + # print("ray marching") + # print(xyzs.shape) + sigmas, rgbs, normals = self(xyzs, dirs, light_d, ratio=ambient_ratio, shading=shading) + + raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh) + # print("composite rays") + rays_alive = rays_alive[rays_alive >= 0] + #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') + + step += n_step + # print("after step") + # mix background color + if self.bg_radius > 0: + + # use the bg model to calculate bg_color + # sph = raymarching.sph_from_ray(rays_o, rays_d, self.bg_radius) # [N, 2] in [-1, 1] + bg_color = self.background(rays_d) # [N, 3] + + elif bg_color is None: + bg_color = 1 + + # image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) + + depth = torch.clamp(depth - nears, min=0) / (fars - nears) + depth = depth.view(*prefix) + + weights_sum = weights_sum.reshape(*prefix) + + mask = (nears < fars).reshape(*prefix) + + results['image'] = image + results['depth'] = depth + results['weights_sum'] = weights_sum + results['mask'] = mask + + # print("results................") + + return results + + + @torch.no_grad() + def update_extra_state(self, decay=0.95, S=128): + # call before each epoch to update extra states. + + if not self.cuda_ray: + return + + ### update density grid + tmp_grid = - torch.ones_like(self.density_grid) + + X = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) + Y = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) + Z = torch.arange(self.grid_size, dtype=torch.int32, device=self.aabb_train.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = custom_meshgrid(xs, ys, zs) + coords = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_size - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_size = bound / self.grid_size + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_size) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size + # query density + sigmas = self.density(cas_xyzs)['sigma'].reshape(-1).detach() + # assign + tmp_grid[cas, indices] = sigmas + + # ema update + valid_mask = self.density_grid >= 0 + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid[valid_mask]).item() + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) + + ### update step counter + total_step = min(16, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 + + # print(f'[density grid] 
min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > density_thresh).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}') + + + def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs): + # rays_o, rays_d: [B, N, 3], assumes B == 1 + # return: pred_rgb: [B, N, 3] + + if self.cuda_ray: + _run = self.run_cuda + else: + _run = self.run + + B, N = rays_o.shape[:2] + device = rays_o.device + + # # never stage when cuda_ray + # if staged and not self.cuda_ray: + # depth = torch.empty((B, N), device=device) + # image = torch.empty((B, N, 3), device=device) + # weights_sum = torch.empty((B, N), device=device) + + # for b in range(B): + # head = 0 + # while head < N: + # tail = min(head + max_ray_batch, N) + # results_ = _run(rays_o[b:b+1, head:tail], rays_d[b:b+1, head:tail], **kwargs) + # depth[b:b+1, head:tail] = results_['depth'] + # weights_sum[b:b+1, head:tail] = results_['weights_sum'] + # image[b:b+1, head:tail] = results_['image'] + # head += max_ray_batch + + # results = {} + # results['depth'] = depth + # results['image'] = image + # results['weights_sum'] = weights_sum + + # else: + results = _run(rays_o, rays_d, **kwargs) + + return results \ No newline at end of file diff --git a/nerf/sd.py b/nerf/sd.py new file mode 100755 index 0000000..4a60953 --- /dev/null +++ b/nerf/sd.py @@ -0,0 +1,220 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler + +# suppress partial model loading warning +logging.set_verbosity_error() + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + +class StableDiffusion(nn.Module): + def __init__(self, device, sd_version='2.0', hf_key=None): + super().__init__() + + self.device = device + self.sd_version = sd_version + + print(f'[INFO] loading stable diffusion...') + + if hf_key is not None: + print(f'[INFO] using hugging face custom model key: {hf_key}') + model_key = hf_key + elif self.sd_version == '2.0': + model_key = "stabilityai/stable-diffusion-2-base" + elif self.sd_version == '1.5': + model_key = "runwayml/stable-diffusion-v1-5" + else: + raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.') + + # Create model + self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae").to(self.device) + self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder").to(self.device) + self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet").to(self.device) + + self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler") + # self.scheduler = PNDMScheduler.from_pretrained(model_key, subfolder="scheduler") + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * 0.02) + self.max_step = int(self.num_train_timesteps * 0.98) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + + print(f'[INFO] loaded stable diffusion!') + + def get_text_embeds(self, prompt, negative_prompt): + # prompt, negative_prompt: [str] + + # Tokenize text and get embeddings + text_input = 
self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt') + + with torch.no_grad(): + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + + # Do the same for unconditional embeddings + uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt') + + with torch.no_grad(): + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # Cat for final embeddings + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + return text_embeddings + + + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100): + + # interp to 512x512 to be fed into vae. + + # _t = time.time() + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # torch.cuda.synchronize(); print(f'[TIME] guiding: interp {time.time() - _t:.4f}s') + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, [1], dtype=torch.long, device=self.device) + + # encode image into latents with vae, requires grad! + # _t = time.time() + latents = self.encode_imgs(pred_rgb_512) + # torch.cuda.synchronize(); print(f'[TIME] guiding: vae enc {time.time() - _t:.4f}s') + + # predict the noise residual with unet, NO grad! + # _t = time.time() + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + # torch.cuda.synchronize(); print(f'[TIME] guiding: unet {time.time() - _t:.4f}s') + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # w(t), sigma_t^2 + w = (1 - self.alphas[t]) + # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t]) + grad = w * (noise_pred - noise) + + # clip grad for stable training? + # grad = grad.clamp(-10, 10) + grad = torch.nan_to_num(grad) + + # manually backward, since we omitted an item in grad and cannot simply autodiff. + # _t = time.time() + latents.backward(gradient=grad, retain_graph=True) + # torch.cuda.synchronize(); print(f'[TIME] guiding: backward {time.time() - _t:.4f}s') + + return 0 # dummy loss value + + def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if latents is None: + latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device) + + self.scheduler.set_timesteps(num_inference_steps) + + with torch.autocast('cuda'): + for i, t in enumerate(self.scheduler.timesteps): + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
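+ # text_embeddings stacks [uncond, cond] (shape [2, 77, d]), so the latents are
+ # duplicated to evaluate both branches in a single UNet call. Note the combine
+ # step below anchors at the conditional branch, text + s * (text - uncond),
+ # which equals the more common uncond + (s + 1) * (text - uncond), i.e. an
+ # effectively stronger guidance scale.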
+ latent_model_input = torch.cat([latents] * 2) + + # predict the noise residual + with torch.no_grad(): + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample'] + + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents)['prev_sample'] + + return latents + + def decode_latents(self, latents): + + latents = 1 / 0.18215 * latents + + with torch.no_grad(): + imgs = self.vae.decode(latents).sample + + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * 0.18215 + + return latents + + def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None): + + if isinstance(prompts, str): + prompts = [prompts] + + if isinstance(negative_prompts, str): + negative_prompts = [negative_prompts] + + # Prompts -> text embeds + text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2, 77, 768] + + # Text embeds -> img latents + latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64] + + # Img latents -> imgs + imgs = self.decode_latents(latents) # [1, 3, 512, 512] + + # Img to Numpy + imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy() + imgs = (imgs * 255).round().astype('uint8') + + return imgs + + +if __name__ == '__main__': + + import argparse + import matplotlib.pyplot as plt + + parser = argparse.ArgumentParser() + parser.add_argument('prompt', type=str) + parser.add_argument('--negative', default='', type=str) + parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0'], help="stable diffusion version") + parser.add_argument('-H', type=int, default=512) + parser.add_argument('-W', type=int, default=512) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--steps', type=int, default=50) + opt = parser.parse_args() + + seed_everything(opt.seed) + + device = torch.device('cuda') + + sd = StableDiffusion(device, opt.sd_version) + + imgs = sd.prompt_to_img(opt.prompt, opt.negative, opt.H, opt.W, opt.steps) + + # visualize image + plt.imshow(imgs[0]) + plt.show() + + + + diff --git a/nerf/utils.py b/nerf/utils.py new file mode 100755 index 0000000..d2ff015 --- /dev/null +++ b/nerf/utils.py @@ -0,0 +1,1087 @@ +import os +import glob +import tqdm +import math +import imageio +import random +import warnings +import tensorboardX + +import numpy as np +import pandas as pd + +import time +from datetime import datetime + +import cv2 +import matplotlib.pyplot as plt + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import torch.distributed as dist +from torch.utils.data import Dataset, DataLoader + +import trimesh +from rich.console import Console +from torch_ema import ExponentialMovingAverage + +from packaging import version as pver + + +def custom_meshgrid(*args): + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, 
indexing='ij') + + +def safe_normalize(x, eps=1e-20): + return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) + + +@torch.cuda.amp.autocast(enabled=False) +def get_rays(poses, intrinsics, H, W, N=-1, error_map=None): + ''' get rays + Args: + poses: [B, 4, 4], cam2world + intrinsics: [4] + H, W, N: int + error_map: [B, 128 * 128], sample probability based on training error + Returns: + rays_o, rays_d: [B, N, 3] + inds: [B, N] + ''' + + device = poses.device + B = poses.shape[0] + fx, fy, cx, cy = intrinsics + + i, j = custom_meshgrid(torch.linspace( + 0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) + i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5 + + results = {} + + if N > 0: + N = min(N, H*W) + + if error_map is None: + inds = torch.randint( + 0, H*W, size=[N], device=device) # may duplicate + inds = inds.expand([B, N]) + else: + + # weighted sample on a low-reso grid + # [B, N], but in [0, 128*128) + inds_coarse = torch.multinomial( + error_map.to(device), N, replacement=False) + + # map to the original resolution with random perturb. + # `//` will throw a warning in torch 1.10... anyway. + inds_x, inds_y = inds_coarse // 128, inds_coarse % 128 + sx, sy = H / 128, W / 128 + inds_x = (inds_x * sx + torch.rand(B, N, device=device) + * sx).long().clamp(max=H - 1) + inds_y = (inds_y * sy + torch.rand(B, N, device=device) + * sy).long().clamp(max=W - 1) + inds = inds_x * W + inds_y + + # need this when updating error_map + results['inds_coarse'] = inds_coarse + + i = torch.gather(i, -1, inds) + j = torch.gather(j, -1, inds) + + results['inds'] = inds + + else: + inds = torch.arange(H*W, device=device).expand([B, H*W]) + + zs = torch.ones_like(i) + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + directions = torch.stack((xs, ys, zs), dim=-1) + directions = safe_normalize(directions) + rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) + + rays_o = poses[..., :3, 3] # [B, 3] + rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] + + results['rays_o'] = rays_o + results['rays_d'] = rays_d + + return results + + +def get_ndc_rays(H, W, focal, near, rays_o, rays_d): + """ + Transform rays from world coordinate to NDC. + NDC: Space such that the canvas is a cube with sides [-1, 1] in each axis. + For detailed derivation, please see: + http://www.songho.ca/opengl/gl_projectionmatrix.html + https://github.com/bmild/nerf/files/4451808/ndc_derivation.pdf + + In practice, use NDC "if and only if" the scene is unbounded (has a large depth). + See https://github.com/bmild/nerf/issues/18 + + Inputs: + H, W, focal: image height, width and focal length + near: (N_rays) or float, the depths of the near plane + rays_o: (N_rays, 3), the origin of the rays in world coordinate + rays_d: (N_rays, 3), the direction of the rays in world coordinate + + Outputs: + rays_o: (N_rays, 3), the origin of the rays in NDC + rays_d: (N_rays, 3), the direction of the rays in NDC + """ + # Shift ray origins to near plane + t = -(near + rays_o[..., 2]) / rays_d[..., 2] + rays_o = rays_o + t[..., None] * rays_d + + # Store some intermediate homogeneous results + ox_oz = rays_o[..., 0] / rays_o[..., 2] + oy_oz = rays_o[..., 1] / rays_o[..., 2] + + # Projection + o0 = -1./(W/(2.*focal)) * ox_oz + o1 = -1./(H/(2.*focal)) * oy_oz + o2 = 1. + 2. 
* near / rays_o[..., 2] + + d0 = -1./(W/(2.*focal)) * (rays_d[..., 0]/rays_d[..., 2] - ox_oz) + d1 = -1./(H/(2.*focal)) * (rays_d[..., 1]/rays_d[..., 2] - oy_oz) + d2 = 1 - o2 + + rays_o = torch.stack([o0, o1, o2], -1) # (B, 3) + rays_d = torch.stack([d0, d1, d2], -1) # (B, 3) + + return rays_o, rays_d + + +def seed_everything(seed): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.benchmark = True + + +def torch_vis_2d(x, renormalize=False): + # x: [3, H, W] or [1, H, W] or [H, W] + import matplotlib.pyplot as plt + import numpy as np + import torch + + if isinstance(x, torch.Tensor): + if len(x.shape) == 3: + x = x.permute(1, 2, 0).squeeze() + x = x.detach().cpu().numpy() + + print(f'[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}') + + x = x.astype(np.float32) + + # renormalize + if renormalize: + x = (x - x.min(axis=0, keepdims=True)) / \ + (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8) + + plt.imshow(x) + plt.show() + + +@torch.jit.script +def linear_to_srgb(x): + return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) + + +@torch.jit.script +def srgb_to_linear(x): + return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) + + +class MSELoss(nn.Module): + def __init__(self): + super(MSELoss, self).__init__() + self.loss = nn.MSELoss(reduction='mean') + + # TODO: modify this to feat + def forward(self, inputs, targets): + + rgb_loss = self.loss(inputs, targets) + + return rgb_loss + + +class Trainer(object): + def __init__(self, + name, # name of this experiment + opt, # extra conf + model, # network + guidance, # guidance network + clip_guidance, + criterion=None, # loss function, if None, assume inline implementation in train_step + optimizer=None, # optimizer + ema_decay=None, # if use EMA, set the decay + lr_scheduler=None, # scheduler + # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. + metrics=[], + local_rank=0, # which GPU am I + world_size=1, # total num of GPUs + # device to use, usually setting to None is OK. 
(auto choose device)
+ device=None,
+ mute=False, # whether to mute all print
+ fp16=False, # whether to use amp mixed-precision training
+ eval_interval=1, # eval once every `eval_interval` epochs
+ max_keep_ckpt=2, # max num of saved ckpts in disk
+ workspace='workspace', # workspace to save logs & ckpts
+ best_mode='min', # the smaller/larger result, the better
+ use_loss_as_metric=True, # use loss as the first metric
+ report_metric_at_train=False, # also report metrics at training
+ use_checkpoint="latest", # which ckpt to use at init time
+ use_tensorboardX=True, # whether to use tensorboard for logging
+ # whether to call scheduler.step() after every train step
+ scheduler_update_every_step=False,
+ pretrained=True, # True for the modified stage, False for the pre-training stage
+ ):
+
+ self.name = name
+ self.opt = opt
+ self.mute = mute
+ self.metrics = metrics
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.workspace = workspace
+ self.ema_decay = ema_decay
+ self.fp16 = fp16
+ self.best_mode = best_mode
+ self.use_loss_as_metric = use_loss_as_metric
+ self.report_metric_at_train = report_metric_at_train
+ self.max_keep_ckpt = max_keep_ckpt
+ self.eval_interval = eval_interval
+ self.use_checkpoint = use_checkpoint
+ self.use_tensorboardX = use_tensorboardX
+ self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
+ self.scheduler_update_every_step = scheduler_update_every_step
+ self.device = device if device is not None else torch.device(
+ f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu')
+ self.console = Console()
+
+ # loss
+ self.rgb_loss = MSELoss()
+ self.pretrained = pretrained
+
+ model.to(self.device)
+ if self.world_size > 1:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[local_rank])
+ self.model = model
+
+ # guide model
+ self.guidance = guidance
+ self.clip_guidance = clip_guidance
+
+ # text prompt
+ if self.guidance is not None:
+
+ for p in self.guidance.parameters():
+ p.requires_grad = False
+
+ for p in self.clip_guidance.parameters():
+ p.requires_grad = False
+
+ self.prepare_text_embeddings()
+
+ else:
+ self.text_z = None
+ self.text_bg_z = None
+
+ if isinstance(criterion, nn.Module):
+ criterion.to(self.device)
+ self.criterion = criterion
+
+ if optimizer is None:
+ self.optimizer = optim.Adam(
+ self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
+ else:
+ self.optimizer = optimizer(self.model)
+
+ if lr_scheduler is None:
+ self.lr_scheduler = optim.lr_scheduler.LambdaLR(
+ self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler
+ else:
+ self.lr_scheduler = lr_scheduler(self.optimizer)
+
+ if ema_decay is not None:
+ self.ema = ExponentialMovingAverage(
+ self.model.parameters(), decay=ema_decay)
+ else:
+ self.ema = None
+
+ self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
+
+ # variable init
+ self.epoch = 0
+ self.global_step = 0
+ self.local_step = 0
+ self.stats = {
+ "loss": [],
+ "valid_loss": [],
+ "results": [], # metrics[0], or valid_loss
+ "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
+ "best_result": None,
+ }
+
+ # auto fix
+ if len(metrics) == 0 or self.use_loss_as_metric:
+ self.best_mode = 'min'
+
+ # workspace prepare
+ self.log_ptr = None
+ if self.workspace is not None:
+ os.makedirs(self.workspace, exist_ok=True)
+ self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
+ self.log_ptr = open(self.log_path, "a+")
+
+ self.ckpt_path = os.path.join(self.workspace, 'checkpoints')
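+ # checkpoint layout: rolling per-epoch snapshots live under
+ # <workspace>/checkpoints (at most max_keep_ckpt are kept on disk), while the
+ # best / EMA weights are written to <workspace>/checkpoints/<name>.pth below.
+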
self.best_path = f"{self.ckpt_path}/{self.name}.pth" + os.makedirs(self.ckpt_path, exist_ok=True) + self.log(opt) + self.log( + f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') + self.log( + f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') + + if self.workspace is not None: + if self.use_checkpoint == "scratch": + self.log("[INFO] Training from scratch ...") + elif self.use_checkpoint == "latest": + self.log("[INFO] Loading latest checkpoint ...") + self.load_checkpoint() + elif self.use_checkpoint == "latest_model": + self.log("[INFO] Loading latest checkpoint (model only)...") + self.load_checkpoint(model_only=True) + elif self.use_checkpoint == "best": + if os.path.exists(self.best_path): + self.log("[INFO] Loading best checkpoint ...") + self.load_checkpoint(self.best_path) + else: + self.log( + f"[INFO] {self.best_path} not found, loading latest ...") + self.load_checkpoint() + else: # path to ckpt + self.log(f"[INFO] Loading {self.use_checkpoint} ...") + self.load_checkpoint(self.use_checkpoint) + + # calculate the text embs. + def prepare_text_embeddings(self): + + if self.opt.text is None: + self.log(f"[WARN] text prompt is not provided.") + self.text_z = None + return + + if not self.opt.dir_text: + self.text_z = self.guidance.get_text_embeds( + [self.opt.text], [self.opt.negative]) + else: + self.text_z = [] + for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']: + # construct dir-encoded text + text = f"{self.opt.text}, {d} view" + + negative_text = f"{self.opt.negative}" + + # explicit negative dir-encoded text + if self.opt.suppress_face: + if negative_text != '': + negative_text += ', ' + + if d == 'back': + negative_text += "face" + # elif d == 'front': negative_text += "" + elif d == 'side': + negative_text += "face" + elif d == 'overhead': + negative_text += "face" + elif d == 'bottom': + negative_text += "face" + + text_z = self.guidance.get_text_embeds([text], [negative_text]) + self.text_z.append(text_z) + self.text_bg_z = self.clip_guidance.get_text_embeds(self.opt.text_bg) + + def __del__(self): + if self.log_ptr: + self.log_ptr.close() + + def log(self, *args, **kwargs): + if self.local_rank == 0: + if not self.mute: + # print(*args) + self.console.print(*args, **kwargs) + if self.log_ptr: + print(*args, file=self.log_ptr) + self.log_ptr.flush() # write immediately to file + + # ------------------------------ + + def train_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + shading = 'albedo' + ambient_ratio = 1.0 + # _t = time.time() + # bg_color = torch.rand((B * N, 3), device=rays_o.device) + outputs = self.model.render(rays_o, rays_d, staged=False, perturb=True, bg_color=None, + ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt)) + pred_rgb = outputs['image'].reshape(B, H, W, 3).permute( + 0, 3, 1, 2).contiguous() # [1, 3, H, W] + + # torch.cuda.synchronize(); print(f'[TIME] nerf render {time.time() - _t:.4f}s') + gt_rgb = data['rgbs'].cuda().reshape( + B, H, W, 3).permute(0, 3, 1, 2).contiguous() + if self.pretrained: + + mask = data['mask'].cuda().reshape(B, H, W, 1).permute( + 0, 3, 1, 2).contiguous() + mask_gt_rgb = gt_rgb * (mask < 0.5) + mask_pred_rgb = pred_rgb * (mask < 0.5) + + # clip loss for black hole + # background text + bg_feature = self.clip_guidance.encode_img(mask_pred_rgb) + 
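+ # both embeddings are L2-normalized below, so text_feature @ bg_feature.T is a
+ # cosine similarity; minimizing its negative (sim_loss) pulls the masked
+ # render toward the background text prompt.
+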
bg_feature = bg_feature.clone() / bg_feature.norm(dim=-1, keepdim=True) + text_feature = self.text_bg_z / \ + self.text_bg_z.norm(dim=-1, keepdim=True) + similarity = text_feature @ bg_feature.T + sim_loss = 1e-2 * -similarity[0][0] + + if self.opt.dir_text: + dirs = data['dir'] + text_z = self.text_z[dirs] + else: + text_z = self.text_z + + # encode pred_rgb to latents + # _t = time.time() + loss = self.guidance.train_step(text_z, pred_rgb) + loss += 100 * self.rgb_loss(mask_pred_rgb, mask_gt_rgb) + loss += sim_loss + else: + + loss = self.rgb_loss(pred_rgb, gt_rgb) + + pred_ws = outputs['weights_sum'].reshape(B, 1, H, W) + # pred_ws = pred_ws * (mask > 0.5) # //TODO + + return pred_rgb, pred_ws, loss + + def eval_step(self, data): + + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + shading = data['shading'] if 'shading' in data else 'albedo' + ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0 + light_d = data['light_d'] if 'light_d' in data else None + + outputs = self.model.render(rays_o, rays_d, staged=True, perturb=False, bg_color=None, light_d=light_d, + ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, **vars(self.opt)) + pred_rgb = outputs['image'].reshape(B, H, W, 3) + pred_depth = outputs['depth'].reshape(B, H, W) + pred_ws = outputs['weights_sum'].reshape(B, H, W) + # mask_ws = outputs['mask'].reshape(B, H, W) # near < far + + # loss_ws = pred_ws.sum() / mask_ws.sum() + # loss_ws = pred_ws.mean() + + alphas = (pred_ws).clamp(1e-5, 1 - 1e-5) + # alphas = alphas ** 2 # skewed entropy, favors 0 over 1 + loss_entropy = (- alphas * torch.log2(alphas) - + (1 - alphas) * torch.log2(1 - alphas)).mean() + + loss = self.opt.lambda_entropy * loss_entropy + + return pred_rgb, pred_depth, loss + + def test_step(self, data, bg_color=None, perturb=False): + # print("test_step ing...............") + rays_o = data['rays_o'] # [B, N, 3] + rays_d = data['rays_d'] # [B, N, 3] + + B, N = rays_o.shape[:2] + H, W = data['H'], data['W'] + + if bg_color is not None: + bg_color = bg_color.to(rays_o.device) + else: + bg_color = torch.ones(3, device=rays_o.device) # [3] + + shading = data['shading'] if 'shading' in data else 'albedo' + ambient_ratio = data['ambient_ratio'] if 'ambient_ratio' in data else 1.0 + light_d = data['light_d'] if 'light_d' in data else None + # print("test_step befor rendering...............") + outputs = self.model.render(rays_o, rays_d, staged=True, perturb=perturb, light_d=light_d, + ambient_ratio=ambient_ratio, shading=shading, force_all_rays=True, bg_color=bg_color, **vars(self.opt)) + + pred_rgb = outputs['image'].reshape(B, H, W, 3) + pred_depth = outputs['depth'].reshape(B, H, W) + + return pred_rgb, pred_depth + + def save_mesh(self, save_path=None, resolution=128): + + if save_path is None: + save_path = os.path.join(self.workspace, 'mesh') + + self.log(f"==> Saving mesh to {save_path}") + + os.makedirs(save_path, exist_ok=True) + + self.model.export_mesh(save_path, resolution=resolution) + + self.log(f"==> Finished saving mesh.") + + # ------------------------------ + + def train(self, train_loader, valid_loader, max_epochs): + # print(self.text_z) + # assert self.text_z is not None, 'Training must provide a text prompt!' 
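+ # one train_one_epoch() pass per epoch: a full checkpoint is saved every epoch,
+ # and evaluation plus a best-checkpoint save run every eval_interval epochs.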
+ + if self.use_tensorboardX and self.local_rank == 0: + self.writer = tensorboardX.SummaryWriter( + os.path.join(self.workspace, "run", self.name)) + + start_t = time.time() + + for epoch in range(self.epoch + 1, max_epochs + 1): + self.epoch = epoch + + self.train_one_epoch(train_loader) + + if self.workspace is not None and self.local_rank == 0: + self.save_checkpoint(full=True, best=False) + + if self.epoch % self.eval_interval == 0: + self.evaluate_one_epoch(valid_loader) + self.save_checkpoint(full=False, best=True) + + end_t = time.time() + + self.log(f"[INFO] training takes {(end_t - start_t)/ 60:.4f} minutes.") + + if self.use_tensorboardX and self.local_rank == 0: + self.writer.close() + + def evaluate(self, loader, name=None): + self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX + self.evaluate_one_epoch(loader, name) + self.use_tensorboardX = use_tensorboardX + + def test(self, loader, save_path=None, name=None, write_video=False): + # print("test.....................") + if save_path is None: + save_path = os.path.join(self.workspace, 'results') + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + os.makedirs(save_path, exist_ok=True) + + self.log(f"==> Start Test, save results to {save_path}") + + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, + bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + self.model.eval() + + if write_video: + all_preds = [] + all_preds_depth = [] + + with torch.no_grad(): + # print("before test step") + for i, data in enumerate(loader): + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, preds_depth = self.test_step(data) + # print("after test step", i) + pred = preds[0].detach().cpu().numpy() + pred = (pred * 255).astype(np.uint8) + + pred_depth = preds_depth[0].detach().cpu().numpy() + pred_depth = (pred_depth * 255).astype(np.uint8) + + if write_video: + all_preds.append(pred) + all_preds_depth.append(pred_depth) + else: + cv2.imwrite(os.path.join(save_path, f'{name}_{i:04d}_rgb.png'), cv2.cvtColor( + pred, cv2.COLOR_RGB2BGR)) + cv2.imwrite(os.path.join( + save_path, f'{name}_{i:04d}_depth.png'), pred_depth) + + pbar.update(loader.batch_size) + + if write_video: + all_preds = np.stack(all_preds, axis=0) + all_preds_depth = np.stack(all_preds_depth, axis=0) + + imageio.mimwrite(os.path.join( + save_path, f'{name}_rgb.mp4'), all_preds, fps=25, quality=8, macro_block_size=1) + imageio.mimwrite(os.path.join( + save_path, f'{name}_depth.mp4'), all_preds_depth, fps=25, quality=8, macro_block_size=1) + + self.log(f"==> Finished Test.") + + # [GUI] train text step. 
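+ # runs `step` optimizer updates on batches drawn from train_loader (the
+ # iterator is re-created when exhausted, so the loader behaves as if infinite)
+ # and returns the average loss and current lr for display.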
+ def train_gui(self, train_loader, step=16): + + self.model.train() + + total_loss = torch.tensor([0], dtype=torch.float32, device=self.device) + + loader = iter(train_loader) + + for _ in range(step): + + # mimic an infinite loop dataloader (in case the total dataset is smaller than step) + try: + data = next(loader) + except StopIteration: + loader = iter(train_loader) + data = next(loader) + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + pred_rgbs, pred_ws, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + total_loss += loss.detach() + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss.item() / step + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + outputs = { + 'loss': average_loss, + 'lr': self.optimizer.param_groups[0]['lr'], + } + + return outputs + + # [GUI] test on a single image + + def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1, downscale=1, light_d=None, ambient_ratio=1.0, shading='albedo'): + + # render resolution (may need downscale to for better frame rate) + rH = int(H * downscale) + rW = int(W * downscale) + intrinsics = intrinsics * downscale + # print(rW, rH) + pose = torch.from_numpy(pose).unsqueeze(0).to(self.device) + + rays = get_rays(pose, intrinsics, rH, rW, -1) + # rays_o, rays_d = get_ndc_rays(rH, rW, ) + + # from degree theta/phi to 3D normalized vec + light_d = np.deg2rad(light_d) + light_d = np.array([ + np.sin(light_d[0]) * np.sin(light_d[1]), + np.cos(light_d[0]), + np.sin(light_d[0]) * np.cos(light_d[1]), + ], dtype=np.float32) + light_d = torch.from_numpy(light_d).to(self.device) + + data = { + 'rays_o': rays['rays_o'], + 'rays_d': rays['rays_d'], + 'H': rH, + 'W': rW, + 'light_d': light_d, + 'ambient_ratio': ambient_ratio, + 'shading': shading, + } + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=self.fp16): + # here spp is used as perturb random seed! + preds, preds_depth = self.test_step( + data, bg_color=bg_color, perturb=False if spp == 1 else spp) + + if self.ema is not None: + self.ema.restore() + + # interpolation to the original resolution + if downscale != 1: + # have to permute twice with torch... 
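+ # (F.interpolate expects channels-first NCHW while preds are NHWC, hence the
+ # permute -> interpolate -> permute round-trip; depth just gets a temporary
+ # channel dim)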
+ preds = F.interpolate(preds.permute(0, 3, 1, 2), size=( + H, W), mode='nearest').permute(0, 2, 3, 1).contiguous() + preds_depth = F.interpolate(preds_depth.unsqueeze( + 1), size=(H, W), mode='nearest').squeeze(1) + + outputs = { + 'image': preds[0].detach().cpu().numpy(), + 'depth': preds_depth[0].detach().cpu().numpy(), + } + + return outputs + + def train_one_epoch(self, loader): + self.log( + f"==> Start Training {self.workspace} Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") + + total_loss = 0 + if self.local_rank == 0 and self.report_metric_at_train: + for metric in self.metrics: + metric.clear() + + self.model.train() + + # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs + # ref: https://pytorch.org/docs/stable/data.html + if self.world_size > 1: + loader.sampler.set_epoch(self.epoch) + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, + bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + self.local_step = 0 + + for data in loader: + + # update grid every 16 steps + if self.model.cuda_ray and self.global_step % self.opt.update_extra_interval == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() + + self.local_step += 1 + self.global_step += 1 + + self.optimizer.zero_grad() + + with torch.cuda.amp.autocast(enabled=self.fp16): + pred_rgbs, pred_ws, loss = self.train_step(data) + + self.scaler.scale(loss).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + + if self.scheduler_update_every_step: + self.lr_scheduler.step() + + loss_val = loss.item() + total_loss += loss_val + + if self.local_rank == 0: + # if self.report_metric_at_train: + # for metric in self.metrics: + # metric.update(preds, truths) + + if self.use_tensorboardX: + self.writer.add_scalar( + "train/loss", loss_val, self.global_step) + self.writer.add_scalar( + "train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) + + if self.scheduler_update_every_step: + pbar.set_description( + f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") + else: + pbar.set_description( + f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + if self.ema is not None: + self.ema.update() + + average_loss = total_loss / self.local_step + self.stats["loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if self.report_metric_at_train: + for metric in self.metrics: + self.log(metric.report(), style="red") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="train") + metric.clear() + + if not self.scheduler_update_every_step: + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step(average_loss) + else: + self.lr_scheduler.step() + + self.log(f"==> Finished Epoch {self.epoch}.") + + def evaluate_one_epoch(self, loader, name=None): + self.log(f"++> Evaluate {self.workspace} at epoch {self.epoch} ...") + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + total_loss = 0 + if self.local_rank == 0: + for metric in self.metrics: + metric.clear() + + self.model.eval() + + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + if self.local_rank == 0: + pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, + bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') + + with torch.no_grad(): + 
self.local_step = 0 + + for data in loader: + self.local_step += 1 + + with torch.cuda.amp.autocast(enabled=self.fp16): + preds, preds_depth, loss = self.eval_step(data) + + # all_gather/reduce the statistics (NCCL only support all_*) + if self.world_size > 1: + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + loss = loss / self.world_size + + preds_list = [torch.zeros_like(preds).to(self.device) for _ in range( + self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(preds_list, preds) + preds = torch.cat(preds_list, dim=0) + + preds_depth_list = [torch.zeros_like(preds_depth).to( + self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] + dist.all_gather(preds_depth_list, preds_depth) + preds_depth = torch.cat(preds_depth_list, dim=0) + + loss_val = loss.item() + total_loss += loss_val + + # only rank = 0 will perform evaluation. + if self.local_rank == 0: + + # save image + save_path = os.path.join( + self.workspace, 'validation', f'{name}_{self.local_step:04d}_rgb.png') + save_path_depth = os.path.join( + self.workspace, 'validation', f'{name}_{self.local_step:04d}_depth.png') + + # self.log(f"==> Saving validation image to {save_path}") + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + pred = preds[0].detach().cpu().numpy() + pred = (pred * 255).astype(np.uint8) + + pred_depth = preds_depth[0].detach().cpu().numpy() + pred_depth = (pred_depth * 255).astype(np.uint8) + + cv2.imwrite(save_path, cv2.cvtColor( + pred, cv2.COLOR_RGB2BGR)) + cv2.imwrite(save_path_depth, pred_depth) + + pbar.set_description( + f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") + pbar.update(loader.batch_size) + + average_loss = total_loss / self.local_step + self.stats["valid_loss"].append(average_loss) + + if self.local_rank == 0: + pbar.close() + if not self.use_loss_as_metric and len(self.metrics) > 0: + result = self.metrics[0].measure() + # if max mode, use -result + self.stats["results"].append( + result if self.best_mode == 'min' else - result) + else: + # if no metric, choose best by min loss + self.stats["results"].append(average_loss) + + for metric in self.metrics: + self.log(metric.report(), style="blue") + if self.use_tensorboardX: + metric.write(self.writer, self.epoch, prefix="evaluate") + metric.clear() + + if self.ema is not None: + self.ema.restore() + + self.log(f"++> Evaluate epoch {self.epoch} Finished.") + + def save_checkpoint(self, name=None, full=False, best=False): + + if name is None: + name = f'{self.name}_ep{self.epoch:04d}' + + state = { + 'epoch': self.epoch, + 'global_step': self.global_step, + 'stats': self.stats, + } + + if self.model.cuda_ray: + state['mean_count'] = self.model.mean_count + state['mean_density'] = self.model.mean_density + + if full: + state['optimizer'] = self.optimizer.state_dict() + state['lr_scheduler'] = self.lr_scheduler.state_dict() + state['scaler'] = self.scaler.state_dict() + if self.ema is not None: + state['ema'] = self.ema.state_dict() + + if not best: + + state['model'] = self.model.state_dict() + + file_path = f"{name}.pth" + + self.stats["checkpoints"].append(file_path) + + if len(self.stats["checkpoints"]) > self.max_keep_ckpt: + old_ckpt = os.path.join( + self.ckpt_path, self.stats["checkpoints"].pop(0)) + if os.path.exists(old_ckpt): + os.remove(old_ckpt) + + torch.save(state, os.path.join(self.ckpt_path, file_path)) + + else: + if len(self.stats["results"]) > 0: + # always save best since loss cannot reflect performance. 
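+ # the best-result comparison is bypassed on purpose: the loss here does not
+ # track visual quality, so the latest evaluated (EMA) weights are always
+ # written to best_path.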
+ if True:
+ # self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}")
+ # self.stats["best_result"] = self.stats["results"][-1]
+
+ # save ema results
+ if self.ema is not None:
+ self.ema.store()
+ self.ema.copy_to()
+
+ state['model'] = self.model.state_dict()
+
+ if self.ema is not None:
+ self.ema.restore()
+
+ torch.save(state, self.best_path)
+ else:
+ self.log(
+ f"[WARN] no evaluated results found, skip saving best checkpoint.")
+
+ def load_checkpoint(self, checkpoint=None, model_only=False):
+ if checkpoint is None:
+ checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/*.pth'))
+ if checkpoint_list:
+ checkpoint = checkpoint_list[-1]
+ self.log(f"[INFO] Latest checkpoint is {checkpoint}")
+ else:
+ self.log(
+ "[WARN] No checkpoint found, model randomly initialized.")
+ return
+
+ checkpoint_dict = torch.load(checkpoint, map_location=self.device)
+
+ if 'model' not in checkpoint_dict:
+ self.model.load_state_dict(checkpoint_dict)
+ self.log("[INFO] loaded model.")
+ return
+
+ missing_keys, unexpected_keys = self.model.load_state_dict(
+ checkpoint_dict['model'], strict=False)
+ self.log("[INFO] loaded model.")
+ if len(missing_keys) > 0:
+ self.log(f"[WARN] missing keys: {missing_keys}")
+ if len(unexpected_keys) > 0:
+ self.log(f"[WARN] unexpected keys: {unexpected_keys}")
+
+ if self.ema is not None and 'ema' in checkpoint_dict:
+ try:
+ self.ema.load_state_dict(checkpoint_dict['ema'])
+ self.log("[INFO] loaded EMA.")
+ except:
+ self.log("[WARN] Failed to load EMA.")
+
+ if self.model.cuda_ray:
+ if 'mean_count' in checkpoint_dict:
+ self.model.mean_count = checkpoint_dict['mean_count']
+ if 'mean_density' in checkpoint_dict:
+ self.model.mean_density = checkpoint_dict['mean_density']
+
+ if model_only:
+ return
+
+ self.stats = checkpoint_dict['stats']
+ self.epoch = checkpoint_dict['epoch']
+ self.global_step = checkpoint_dict['global_step']
+ self.log(
+ f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
+
+ if self.optimizer and 'optimizer' in checkpoint_dict:
+ try:
+ self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
+ self.log("[INFO] loaded optimizer.")
+ except:
+ self.log("[WARN] Failed to load optimizer.")
+
+ if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
+ try:
+ self.lr_scheduler.load_state_dict(
+ checkpoint_dict['lr_scheduler'])
+ self.log("[INFO] loaded scheduler.")
+ except:
+ self.log("[WARN] Failed to load scheduler.")
+
+ if self.scaler and 'scaler' in checkpoint_dict:
+ try:
+ self.scaler.load_state_dict(checkpoint_dict['scaler'])
+ self.log("[INFO] loaded scaler.")
+ except:
+ self.log("[WARN] Failed to load scaler.")
diff --git a/optimizer.py b/optimizer.py
new file mode 100755
index 0000000..f5bb64f
--- /dev/null
+++ b/optimizer.py
@@ -0,0 +1,325 @@
+# Copyright 2022 Garena Online Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
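+#
+# Usage sketch (illustrative; `model`, `criterion` and `loader` stand in for
+# whatever the caller defines):
+#
+#   from optimizer import Adan
+#   optimizer = Adan(model.parameters(), lr=1e-3, betas=(0.98, 0.92, 0.99),
+#                    weight_decay=2e-5, max_grad_norm=5.0)
+#   for x, y in loader:
+#       optimizer.zero_grad()
+#       loss = criterion(model(x), y)
+#       loss.backward()
+#       optimizer.step()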
+
+import math
+from typing import List
+
+import torch
+from torch import Tensor
+from torch.optim.optimizer import Optimizer
+
+
+class Adan(Optimizer):
+    """
+    Implements a PyTorch variant of Adan.
+    Adan was proposed in
+    Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep
+    Models. arXiv preprint arXiv:2208.06677, 2022.
+    https://arxiv.org/abs/2208.06677
+    Arguments:
+        params (iterable): iterable of parameters to optimize or
+            dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float, float], optional): coefficients used for
+            first- and second-order moments. (default: (0.98, 0.92, 0.99))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-8)
+        weight_decay (float, optional): decoupled weight decay
+            (L2 penalty) (default: 0)
+        max_grad_norm (float, optional): value used to clip
+            global grad norm (default: 0.0, no clipping)
+        no_prox (bool): how to perform the decoupled weight decay
+            (default: False)
+        foreach (bool): if True, use the torch._foreach implementation.
+            It's faster but uses slightly more memory. (default: True)
+    """
+    def __init__(self,
+                 params,
+                 lr=1e-3,
+                 betas=(0.98, 0.92, 0.99),
+                 eps=1e-8,
+                 weight_decay=0.0,
+                 max_grad_norm=0.0,
+                 no_prox=False,
+                 foreach: bool = True):
+        if not 0.0 <= max_grad_norm:
+            raise ValueError('Invalid Max grad norm: {}'.format(max_grad_norm))
+        if not 0.0 <= lr:
+            raise ValueError('Invalid learning rate: {}'.format(lr))
+        if not 0.0 <= eps:
+            raise ValueError('Invalid epsilon value: {}'.format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError('Invalid beta parameter at index 0: {}'.format(
+                betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError('Invalid beta parameter at index 1: {}'.format(
+                betas[1]))
+        if not 0.0 <= betas[2] < 1.0:
+            raise ValueError('Invalid beta parameter at index 2: {}'.format(
+                betas[2]))
+        defaults = dict(lr=lr,
+                        betas=betas,
+                        eps=eps,
+                        weight_decay=weight_decay,
+                        max_grad_norm=max_grad_norm,
+                        no_prox=no_prox,
+                        foreach=foreach)
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Adan, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('no_prox', False)
+
+    @torch.no_grad()
+    def restart_opt(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                if p.requires_grad:
+                    state = self.state[p]
+                    # State initialization
+
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+                    # Exponential moving average of gradient difference
+                    state['exp_avg_diff'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step."""
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if self.defaults['max_grad_norm'] > 0:
+            device = self.param_groups[0]['params'][0].device
+            global_grad_norm = torch.zeros(1, device=device)
+
+            max_grad_norm = torch.tensor(self.defaults['max_grad_norm'],
+                                         device=device)
+            for group in self.param_groups:
+
+                for p in group['params']:
+                    if p.grad is not None:
+                        grad = p.grad
+                        global_grad_norm.add_(grad.pow(2).sum())
+
+            global_grad_norm = torch.sqrt(global_grad_norm)
+
+            clip_global_grad_norm = torch.clamp(
+                max_grad_norm / (global_grad_norm + group['eps']),
+                max=1.0).item()
+        else:
+            clip_global_grad_norm = 1.0
+
+        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            exp_avg_sqs = []
+            exp_avg_diffs = []
+            neg_pre_grads = []
+
+            beta1, beta2, beta3 = group['betas']
+            # assume the same step across the group for now to simplify things;
+            # a per-parameter step could easily be supported by making it a
+            # tensor, or by passing a list into the kernel
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            bias_correction1 = 1.0 - beta1**group['step']
+            bias_correction2 = 1.0 - beta2**group['step']
+            bias_correction3 = 1.0 - beta3**group['step']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                params_with_grad.append(p)
+                grads.append(p.grad)
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['exp_avg'] = torch.zeros_like(p)
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+                    state['exp_avg_diff'] = torch.zeros_like(p)
+
+                if 'neg_pre_grad' not in state or group['step'] == 1:
+                    state['neg_pre_grad'] = p.grad.clone().mul_(
+                        -clip_global_grad_norm)
+
+                exp_avgs.append(state['exp_avg'])
+                exp_avg_sqs.append(state['exp_avg_sq'])
+                exp_avg_diffs.append(state['exp_avg_diff'])
+                neg_pre_grads.append(state['neg_pre_grad'])
+
+            kwargs = dict(
+                params=params_with_grad,
+                grads=grads,
+                exp_avgs=exp_avgs,
+                exp_avg_sqs=exp_avg_sqs,
+                exp_avg_diffs=exp_avg_diffs,
+                neg_pre_grads=neg_pre_grads,
+                beta1=beta1,
+                beta2=beta2,
+                beta3=beta3,
+                bias_correction1=bias_correction1,
+                bias_correction2=bias_correction2,
+                bias_correction3_sqrt=math.sqrt(bias_correction3),
+                lr=group['lr'],
+                weight_decay=group['weight_decay'],
+                eps=group['eps'],
+                no_prox=group['no_prox'],
+                clip_global_grad_norm=clip_global_grad_norm,
+            )
+
+            if group['foreach']:
+                _multi_tensor_adan(**kwargs)
+            else:
+                _single_tensor_adan(**kwargs)
+
+        return loss
+
+
+def _single_tensor_adan(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    exp_avg_diffs: List[Tensor],
+    neg_pre_grads: List[Tensor],
+    *,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    bias_correction1: float,
+    bias_correction2: float,
+    bias_correction3_sqrt: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    no_prox: bool,
+    clip_global_grad_norm: Tensor,
+):
+    for i, param in enumerate(params):
+        grad = grads[i]
+        exp_avg = exp_avgs[i]
+        exp_avg_sq = exp_avg_sqs[i]
+        exp_avg_diff = exp_avg_diffs[i]
+        neg_grad_or_diff = neg_pre_grads[i]
+
+        grad.mul_(clip_global_grad_norm)
+
+        # to save memory, `neg_grad_or_diff` is reused as an in-place
+        # temporary; here it becomes grad_t - grad_{t-1}
+        neg_grad_or_diff.add_(grad)
+
+        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # m_t
+        exp_avg_diff.mul_(beta2).add_(neg_grad_or_diff,
+                                      alpha=1 - beta2)  # diff_t
+
+        neg_grad_or_diff.mul_(beta2).add_(grad)
+        exp_avg_sq.mul_(beta3).addcmul_(neg_grad_or_diff,
+                                        neg_grad_or_diff,
+                                        value=1 - beta3)  # n_t
+
+        denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps)
+        step_size_diff = lr * beta2 / bias_correction2
+        step_size = lr / bias_correction1
+
+        if no_prox:
+            param.mul_(1 - lr * weight_decay)
+            param.addcdiv_(exp_avg, denom, value=-step_size)
+            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+        else:
+            param.addcdiv_(exp_avg, denom, value=-step_size)
+            param.addcdiv_(exp_avg_diff, denom, value=-step_size_diff)
+            param.div_(1 + lr * weight_decay)
+
+        neg_grad_or_diff.zero_().add_(grad, alpha=-1.0)
+
+
+def _multi_tensor_adan(
+    params: List[Tensor],
+    grads: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avg_sqs: List[Tensor],
+    exp_avg_diffs: List[Tensor],
+    neg_pre_grads: List[Tensor],
+    *,
+    beta1: float,
+    beta2: float,
+    beta3: float,
+    bias_correction1: float,
+    bias_correction2: float,
+    bias_correction3_sqrt: float,
+    lr: float,
+    weight_decay: float,
+    eps: float,
+    no_prox: bool,
+    clip_global_grad_norm: Tensor,
+):
+    if len(params) == 0:
+        return
+
+    torch._foreach_mul_(grads, clip_global_grad_norm)
+
+    # to save memory, `neg_pre_grads` is reused as an in-place temporary
+    torch._foreach_add_(neg_pre_grads, grads)
+
+    torch._foreach_mul_(exp_avgs, beta1)
+    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)  # m_t
+
+    torch._foreach_mul_(exp_avg_diffs, beta2)
+    torch._foreach_add_(exp_avg_diffs, neg_pre_grads,
+                        alpha=1 - beta2)  # diff_t
+
+    torch._foreach_mul_(neg_pre_grads, beta2)
+    torch._foreach_add_(neg_pre_grads, grads)
+    torch._foreach_mul_(exp_avg_sqs, beta3)
+    torch._foreach_addcmul_(exp_avg_sqs,
+                            neg_pre_grads,
+                            neg_pre_grads,
+                            value=1 - beta3)  # n_t
+
+    denom = torch._foreach_sqrt(exp_avg_sqs)
+    torch._foreach_div_(denom, bias_correction3_sqrt)
+    torch._foreach_add_(denom, eps)
+
+    step_size_diff = lr * beta2 / bias_correction2
+    step_size = lr / bias_correction1
+
+    if no_prox:
+        torch._foreach_mul_(params, 1 - lr * weight_decay)
+        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+        torch._foreach_addcdiv_(params,
+                                exp_avg_diffs,
+                                denom,
+                                value=-step_size_diff)
+    else:
+        torch._foreach_addcdiv_(params, exp_avgs, denom, value=-step_size)
+        torch._foreach_addcdiv_(params,
+                                exp_avg_diffs,
+                                denom,
+                                value=-step_size_diff)
+        torch._foreach_div_(params, 1 + lr * weight_decay)
+    torch._foreach_zero_(neg_pre_grads)
+    torch._foreach_add_(neg_pre_grads, grads, alpha=-1.0)
\ No newline at end of file
diff --git a/raymarching/__init__.py b/raymarching/__init__.py
new file mode 100755
index 0000000..26d3cc6
--- /dev/null
+++ b/raymarching/__init__.py
@@ -0,0 +1 @@
+from .raymarching import *
\ No newline at end of file
diff --git a/raymarching/backend.py b/raymarching/backend.py
new file mode 100755
index 0000000..a6a9a03
--- /dev/null
+++ b/raymarching/backend.py
@@ -0,0 +1,40 @@
+import os
+from torch.utils.cpp_extension import load
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+
+nvcc_flags = [
+    '-O3', '-std=c++14',
+    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
+]
+
+if os.name == "posix":
+    c_flags = ['-O3', '-std=c++14']
+elif os.name == "nt":
+    c_flags = ['/O2', '/std:c++17']
+
+    # find cl.exe
+    def find_cl_path():
+        import glob
+        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
+            paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
+            if paths:
+                return paths[0]
+
+    # If cl.exe is not on path, try to find it.
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_raymarching', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'raymarching.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/raymarching/raymarching.py b/raymarching/raymarching.py new file mode 100755 index 0000000..d1d9b6d --- /dev/null +++ b/raymarching/raymarching.py @@ -0,0 +1,385 @@ +import numpy as np +import time + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +# lazy building: +# `import raymarching` will not immediately build the extension, only if you actually call any functions. + +BACKEND = None + +def get_backend(): + global BACKEND + + if BACKEND is None: + try: + import _raymarching as _backend + except ImportError: + from .backend import _backend + + BACKEND = _backend + + return BACKEND + +# ---------------------------------------- +# utils +# ---------------------------------------- + +class _near_far_from_aabb(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): + ''' near_far_from_aabb, CUDA implementation + Calculate rays' intersection time (near and far) with aabb + Args: + rays_o: float, [N, 3] + rays_d: float, [N, 3] + aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) + min_near: float, scalar + Returns: + nears: float, [N] + fars: float, [N] + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + + get_backend().near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) + + return nears, fars + +near_far_from_aabb = _near_far_from_aabb.apply + + +class _sph_from_ray(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, radius): + ''' sph_from_ray, CUDA implementation + get spherical coordinate on the background sphere from rays. + Assume rays_o are inside the Sphere(radius). + Args: + rays_o: [N, 3] + rays_d: [N, 3] + radius: scalar, float + Return: + coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) + ''' + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) + + get_backend().sph_from_ray(rays_o, rays_d, radius, N, coords) + + return coords + +sph_from_ray = _sph_from_ray.apply + + +class _morton3D(Function): + @staticmethod + def forward(ctx, coords): + ''' morton3D, CUDA implementation + Args: + coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) + TODO: check if the coord range is valid! 
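+            Bits of x, y and z are interleaved (Z-order / Morton curve), e.g.
+            (1, 0, 0) -> 1, (0, 1, 0) -> 2, (0, 0, 1) -> 4, (3, 0, 0) -> 9,
+            so spatially nearby cells map to nearby indices.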
(current 128 is safe) + Returns: + indices: [N], int32, in [0, 128^3) + + ''' + if not coords.is_cuda: coords = coords.cuda() + + N = coords.shape[0] + + indices = torch.empty(N, dtype=torch.int32, device=coords.device) + + get_backend().morton3D(coords.int(), N, indices) + + return indices + +morton3D = _morton3D.apply + +class _morton3D_invert(Function): + @staticmethod + def forward(ctx, indices): + ''' morton3D_invert, CUDA implementation + Args: + indices: [N], int32, in [0, 128^3) + Returns: + coords: [N, 3], int32, in [0, 128) + + ''' + if not indices.is_cuda: indices = indices.cuda() + + N = indices.shape[0] + + coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) + + get_backend().morton3D_invert(indices.int(), N, coords) + + return coords + +morton3D_invert = _morton3D_invert.apply + + +class _packbits(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, thresh, bitfield=None): + ''' packbits, CUDA implementation + Pack up the density grid into a bit field to accelerate ray marching. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + thresh: float, threshold + Returns: + bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + N = C * H3 // 8 + + if bitfield is None: + bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) + + get_backend().packbits(grid, N, thresh, bitfield) + + return bitfield + +packbits = _packbits.apply + +# ---------------------------------------- +# train functions +# ---------------------------------------- + +class _march_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only) + Args: + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + step_counter: int32, (2), used to count the actual number of generated points. + mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) + perturb: bool + align: int, pad output so its size is dividable by align, set to -1 to disable. + force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) + dirs: float, [M, 3], all generated points' view dirs. + deltas: float, [M, 2], all generated points' deltas. 
(first for RGB, second for depth)
+            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
+        '''
+
+        if not rays_o.is_cuda: rays_o = rays_o.cuda()
+        if not rays_d.is_cuda: rays_d = rays_d.cuda()
+        if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()
+
+        rays_o = rays_o.contiguous().view(-1, 3)
+        rays_d = rays_d.contiguous().view(-1, 3)
+        density_bitfield = density_bitfield.contiguous()
+
+        N = rays_o.shape[0] # num rays
+        M = N * max_steps # init max points number in total
+
+        # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
+        # It estimates the max number of points to enable faster training, but rays will be randomly truncated if it is underestimated.
+        if not force_all_rays and mean_count > 0:
+            if align > 0:
+                mean_count += align - mean_count % align
+            M = mean_count
+
+        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
+        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)
+        rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
+
+        if step_counter is None:
+            step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter
+
+        if perturb:
+            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
+        else:
+            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)
+
+        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the number of points actually used
+
+        #print(step_counter, M)
+
+        # only used at the first (few) epochs.
+        if force_all_rays or mean_count <= 0:
+            m = step_counter[0].item() # D2H copy
+            if align > 0:
+                m += align - m % align
+            xyzs = xyzs[:m]
+            dirs = dirs[:m]
+            deltas = deltas[:m]
+
+            torch.cuda.empty_cache()
+
+        return xyzs, dirs, deltas, rays
+
+march_rays_train = _march_rays_train.apply
+
+
+class _composite_rays_train(Function):
+    @staticmethod
+    @custom_fwd(cast_inputs=torch.float32)
+    def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
+        ''' composite rays' rgbs, according to the ray marching formula.
+        Args:
+            rgbs: float, [M, 3]
+            sigmas: float, [M,]
+            deltas: float, [M, 2]
+            rays: int32, [N, 3]
+        Returns:
+            weights_sum: float, [N,], the alpha channel
+            depth: float, [N,], the depth
+            image: float, [N, 3], the RGB channel (after multiplying alpha!)
+        '''
+
+        sigmas = sigmas.contiguous()
+        rgbs = rgbs.contiguous()
+
+        M = sigmas.shape[0]
+        N = rays.shape[0]
+
+        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
+        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)
+
+        get_backend().composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)
+
+        ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
+        ctx.dims = [M, N, T_thresh]
+
+        return weights_sum, depth, image
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_weights_sum, grad_depth, grad_image):
+
+        # NOTE: grad_depth is not used now! It won't be propagated to sigmas.
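+        # The CUDA backward kernel replays the forward accumulation to recover
+        # each sample's transmittance and partial color, then applies
+        #   dL/drgb_i   = w_i * dL/dC
+        #   dL/dsigma_i = delta_i * ( dL/dC . (T * rgb_i - (C - C_i))
+        #                             + dL/dW * (1 - W) )
+        # with T the transmittance after sample i, C_i the color accumulated
+        # through it, and C / W the final color / alpha.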
+ + grad_weights_sum = grad_weights_sum.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors + M, N, T_thresh = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + + get_backend().composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs) + + return grad_sigmas, grad_rgbs, None, None, None + + +composite_rays_train = _composite_rays_train.apply + +# ---------------------------------------- +# infer functions +# ---------------------------------------- + +class _march_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024): + ''' march rays to generate points (forward only, for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) + rays_t: float, [N], the alive rays' time, we only use the first n_alive. + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + align: int, pad output so its size is dividable by align, set to -1 to disable. + perturb: bool/int, int > 0 is used as the random seed. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [n_alive * n_step, 3], all generated points' coords + dirs: float, [n_alive * n_step, 3], all generated points' view dirs. + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + ''' + + if not rays_o.is_cuda: rays_o = rays_o.cuda() + if not rays_d.is_cuda: rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + M = n_alive * n_step + + if align > 0: + M += align - (M % align) + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth + + if perturb: + # torch.manual_seed(perturb) # test_gui uses spp index as seed + noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) + + get_backend().march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises) + + return xyzs, dirs, deltas + +march_rays = _march_rays.apply + + +class _composite_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2): + ''' composite rays' rgbs, according to the ray marching formula. 
(for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) + rays_t: float, [N], the alive rays' time + sigmas: float, [n_alive * n_step,] + rgbs: float, [n_alive * n_step, 3] + deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). + In-place Outputs: + weights_sum: float, [N,], the alpha channel + depth: float, [N,], the depth value + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + get_backend().composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image) + return tuple() + + +composite_rays = _composite_rays.apply \ No newline at end of file diff --git a/raymarching/setup.py b/raymarching/setup.py new file mode 100755 index 0000000..d974499 --- /dev/null +++ b/raymarching/setup.py @@ -0,0 +1,62 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +''' +Usage: + +python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) + +python setup.py install # build extensions and install (copy) to PATH. +pip install . # ditto but better (e.g., dependency & metadata handling) + +python setup.py develop # build extensions and install (symbolic) to PATH. +pip install -e . 
# ditto but better (e.g., dependency & metadata handling)
+
+'''
+setup(
+    name='raymarching', # package name, import this to use python API
+    ext_modules=[
+        CUDAExtension(
+            name='_raymarching', # extension name, import this to use CUDA API
+            sources=[os.path.join(_src_path, 'src', f) for f in [
+                'raymarching.cu',
+                'bindings.cpp',
+            ]],
+            extra_compile_args={
+                'cxx': c_flags,
+                'nvcc': nvcc_flags,
+            }
+        ),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension,
+    }
+)
\ No newline at end of file
diff --git a/raymarching/src/bindings.cpp b/raymarching/src/bindings.cpp
new file mode 100755
index 0000000..47920bc
--- /dev/null
+++ b/raymarching/src/bindings.cpp
@@ -0,0 +1,19 @@
+#include <torch/extension.h>
+
+#include "raymarching.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    // utils
+    m.def("packbits", &packbits, "packbits (CUDA)");
+    m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
+    m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
+    m.def("morton3D", &morton3D, "morton3D (CUDA)");
+    m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
+    // train
+    m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
+    m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
+    m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
+    // infer
+    m.def("march_rays", &march_rays, "march rays (CUDA)");
+    m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
+}
\ No newline at end of file
diff --git a/raymarching/src/raymarching.cu b/raymarching/src/raymarching.cu
new file mode 100755
index 0000000..1606503
--- /dev/null
+++ b/raymarching/src/raymarching.cu
@@ -0,0 +1,914 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/torch.h>
+
+#include <cstdio>
+#include <stdint.h>
+#include <stdexcept>
+#include <limits>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
+#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
+
+
+inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
+inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
+inline constexpr __device__ float PI() { return 3.141592653589793f; }
+inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
+
+
+template <typename T>
+inline __host__ __device__ T div_round_up(T val, T divisor) {
+    return (val + divisor - 1) / divisor;
+}
+
+inline __host__ __device__ float signf(const float x) {
+    return copysignf(1.0, x);
+}
+
+inline __host__ __device__ float clamp(const float x, const float min, const float max) {
+    return fminf(max, fmaxf(min, x));
+}
+
+inline __host__ __device__ void swapf(float& a, float& b) {
+    float c = a; a = b; b = c;
+}
+
+inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
+    const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
+    int exponent;
+    frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
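+    // frexpf returns the binary exponent e with |mx| in [0.5, 1) * 2^e, so the
+    // cascade index grows by one each time the point's largest coordinate
+    // doubles, matching the instant-ngp multi-scale grid layout.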
+    return fminf(max_cascade - 1, fmaxf(0, exponent));
+}
+
+inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
+    const float mx = dt * H * 0.5;
+    int exponent;
+    frexpf(mx, &exponent);
+    return fminf(max_cascade - 1, fmaxf(0, exponent));
+}
+
+inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
+{
+    v = (v * 0x00010001u) & 0xFF0000FFu;
+    v = (v * 0x00000101u) & 0x0F00F00Fu;
+    v = (v * 0x00000011u) & 0xC30C30C3u;
+    v = (v * 0x00000005u) & 0x49249249u;
+    return v;
+}
+
+inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
+{
+    uint32_t xx = __expand_bits(x);
+    uint32_t yy = __expand_bits(y);
+    uint32_t zz = __expand_bits(z);
+    return xx | (yy << 1) | (zz << 2);
+}
+
+inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
+{
+    x = x & 0x49249249;
+    x = (x | (x >> 2)) & 0xc30c30c3;
+    x = (x | (x >> 4)) & 0x0f00f00f;
+    x = (x | (x >> 8)) & 0xff0000ff;
+    x = (x | (x >> 16)) & 0x0000ffff;
+    return x;
+}
+
+
+////////////////////////////////////////////////////
+/////////////           utils          /////////////
+////////////////////////////////////////////////////
+
+// rays_o/d: [N, 3]
+// nears/fars: [N]
+// scalar_t should always be float in use.
+template <typename scalar_t>
+__global__ void kernel_near_far_from_aabb(
+    const scalar_t * __restrict__ rays_o,
+    const scalar_t * __restrict__ rays_d,
+    const scalar_t * __restrict__ aabb,
+    const uint32_t N,
+    const float min_near,
+    scalar_t * nears, scalar_t * fars
+) {
+    // parallel per ray
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    rays_o += n * 3;
+    rays_d += n * 3;
+
+    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+
+    // get near far (assume cube scene)
+    float near = (aabb[0] - ox) * rdx;
+    float far = (aabb[3] - ox) * rdx;
+    if (near > far) swapf(near, far);
+
+    float near_y = (aabb[1] - oy) * rdy;
+    float far_y = (aabb[4] - oy) * rdy;
+    if (near_y > far_y) swapf(near_y, far_y);
+
+    if (near > far_y || near_y > far) {
+        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
+        return;
+    }
+
+    if (near_y > near) near = near_y;
+    if (far_y < far) far = far_y;
+
+    float near_z = (aabb[2] - oz) * rdz;
+    float far_z = (aabb[5] - oz) * rdz;
+    if (near_z > far_z) swapf(near_z, far_z);
+
+    if (near > far_z || near_z > far) {
+        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
+        return;
+    }
+
+    if (near_z > near) near = near_z;
+    if (far_z < far) far = far_z;
+
+    if (near < min_near) near = min_near;
+
+    nears[n] = near;
+    fars[n] = far;
+}
+
+
+void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
+
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    rays_o.scalar_type(), "near_far_from_aabb", ([&] {
+        kernel_near_far_from_aabb<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
+    }));
+}
+
+
+// rays_o/d: [N, 3]
+// radius: float
+// coords: [N, 2]
+template <typename scalar_t>
+__global__ void kernel_sph_from_ray(
+    const scalar_t * __restrict__ rays_o,
+    const scalar_t * __restrict__ rays_d,
+    const float radius,
+    const uint32_t N,
+    scalar_t * coords
+) {
+    // parallel per ray
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    rays_o += n * 3;
+    rays_d += n * 3;
+    coords += n * 2;
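+    // Intersect || o + t*d ||^2 = radius^2: a quadratic in t whose larger
+    // root is the forward hit on the background sphere (rays start inside).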
+
+    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+
+    // solve t from || o + td || = radius
+    const float A = dx * dx + dy * dy + dz * dz;
+    const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
+    const float C = ox * ox + oy * oy + oz * oz - radius * radius;
+
+    const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)
+
+    // solve theta, phi (assume y is the up axis)
+    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
+    const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
+    const float phi = atan2(z, x); // [-PI, PI)
+
+    // normalize to [-1, 1]
+    coords[0] = 2 * theta * RPI() - 1;
+    coords[1] = phi * RPI();
+}
+
+
+void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
+
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    rays_o.scalar_type(), "sph_from_ray", ([&] {
+        kernel_sph_from_ray<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
+    }));
+}
+
+
+// coords: int32, [N, 3]
+// indices: int32, [N]
+__global__ void kernel_morton3D(
+    const int * __restrict__ coords,
+    const uint32_t N,
+    int * indices
+) {
+    // parallel
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    coords += n * 3;
+    indices[n] = __morton3D(coords[0], coords[1], coords[2]);
+}
+
+
+void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
+    static constexpr uint32_t N_THREAD = 128;
+    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
+}
+
+
+// indices: int32, [N]
+// coords: int32, [N, 3]
+__global__ void kernel_morton3D_invert(
+    const int * __restrict__ indices,
+    const uint32_t N,
+    int * coords
+) {
+    // parallel
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    coords += n * 3;
+
+    const int ind = indices[n];
+
+    coords[0] = __morton3D_invert(ind >> 0);
+    coords[1] = __morton3D_invert(ind >> 1);
+    coords[2] = __morton3D_invert(ind >> 2);
+}
+
+
+void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
+    static constexpr uint32_t N_THREAD = 128;
+    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
+}
+
+
+// grid: float, [C, H, H, H]
+// N: int, C * H * H * H / 8
+// density_thresh: float
+// bitfield: uint8, [N]
+template <typename scalar_t>
+__global__ void kernel_packbits(
+    const scalar_t * __restrict__ grid,
+    const uint32_t N,
+    const float density_thresh,
+    uint8_t * bitfield
+) {
+    // parallel per byte
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    grid += n * 8;
+
+    uint8_t bits = 0;
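+    // Each output byte covers 8 consecutive grid cells (in Morton order):
+    // bit i is set iff the density of cell i exceeds density_thresh, so one
+    // uint8 lookup in the ray marcher replaces 8 float comparisons.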
+    #pragma unroll
+    for (uint8_t i = 0; i < 8; i++) {
+        bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
+    }
+
+    bitfield[n] = bits;
+}
+
+
+void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
+
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    grid.scalar_type(), "packbits", ([&] {
+        kernel_packbits<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
+    }));
+}
+
+////////////////////////////////////////////////////
+/////////////         training         /////////////
+////////////////////////////////////////////////////
+
+// rays_o/d: [N, 3]
+// grid: [CHHH / 8]
+// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
+// dirs: [M, 3]
+// rays: [N, 3], idx, offset, num_steps
+template <typename scalar_t>
+__global__ void kernel_march_rays_train(
+    const scalar_t * __restrict__ rays_o,
+    const scalar_t * __restrict__ rays_d,
+    const uint8_t * __restrict__ grid,
+    const float bound,
+    const float dt_gamma, const uint32_t max_steps,
+    const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
+    const scalar_t* __restrict__ nears,
+    const scalar_t* __restrict__ fars,
+    scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
+    int * rays,
+    int * counter,
+    const scalar_t* __restrict__ noises
+) {
+    // parallel per ray
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    rays_o += n * 3;
+    rays_d += n * 3;
+
+    // ray marching
+    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+    const float rH = 1 / (float)H;
+    const float H3 = H * H * H;
+
+    const float near = nears[n];
+    const float far = fars[n];
+    const float noise = noises[n];
+
+    const float dt_min = 2 * SQRT3() / max_steps;
+    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
+
+    float t0 = near;
+
+    // perturb
+    t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
+
+    // first pass: estimation of num_steps
+    float t = t0;
+    uint32_t num_steps = 0;
+
+    //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);
+
+    while (t < far && num_steps < max_steps) {
+        // current point
+        const float x = clamp(ox + t * dx, -bound, bound);
+        const float y = clamp(oy + t * dy, -bound, bound);
+        const float z = clamp(oz + t * dz, -bound, bound);
+
+        const float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+        // get mip level
+        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
+        const float mip_rbound = 1 / mip_bound;
+
+        // convert to nearest grid position
+        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+        const bool occ = grid[index / 8] & (1 << (index % 8));
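+        // Empty-space skipping: when the cell is empty, compute the distance
+        // to the next voxel boundary on each axis (tx/ty/tz below) and step
+        // forward until that boundary is crossed, instead of sampling densely.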
+
+        // if occupied, advance a small step, and write to output
+        //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);
+
+        if (occ) {
+            num_steps++;
+            t += dt;
+        // else, skip a large step (basically skip a voxel grid)
+        } else {
+            // calc distance to next voxel
+            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
+            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
+            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
+
+            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+            // step until next voxel
+            do {
+                t += clamp(t * dt_gamma, dt_min, dt_max);
+            } while (t < tt);
+        }
+    }
+
+    //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);
+
+    // second pass: really locate and write points & dirs
+    uint32_t point_index = atomicAdd(counter, num_steps);
+    uint32_t ray_index = atomicAdd(counter + 1, 1);
+
+    //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);
+
+    // write rays
+    rays[ray_index * 3] = n;
+    rays[ray_index * 3 + 1] = point_index;
+    rays[ray_index * 3 + 2] = num_steps;
+
+    if (num_steps == 0) return;
+    if (point_index + num_steps > M) return;
+
+    xyzs += point_index * 3;
+    dirs += point_index * 3;
+    deltas += point_index * 2;
+
+    t = t0;
+    uint32_t step = 0;
+
+    float last_t = t;
+
+    while (t < far && step < num_steps) {
+        // current point
+        const float x = clamp(ox + t * dx, -bound, bound);
+        const float y = clamp(oy + t * dy, -bound, bound);
+        const float z = clamp(oz + t * dz, -bound, bound);
+
+        const float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+        // get mip level
+        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+        const float mip_bound = fminf(scalbnf(1.0f, level), bound);
+        const float mip_rbound = 1 / mip_bound;
+
+        // convert to nearest grid position
+        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+        // query grid
+        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+        const bool occ = grid[index / 8] & (1 << (index % 8));
+
+        // if occupied, advance a small step, and write to output
+        if (occ) {
+            // write step
+            xyzs[0] = x;
+            xyzs[1] = y;
+            xyzs[2] = z;
+            dirs[0] = dx;
+            dirs[1] = dy;
+            dirs[2] = dz;
+            t += dt;
+            deltas[0] = dt;
+            deltas[1] = t - last_t; // used to calc depth
+            last_t = t;
+            xyzs += 3;
+            dirs += 3;
+            deltas += 2;
+            step++;
+        // else, skip a large step (basically skip a voxel grid)
+        } else {
+            // calc distance to next voxel
+            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
+            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
+            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
+            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+            // step until next voxel
+            do {
+                t += clamp(t * dt_gamma, dt_min, dt_max);
+            } while (t < tt);
+        }
+    }
+}
+
+void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises) {
+
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    rays_o.scalar_type(), "march_rays_train", ([&] {
+        kernel_march_rays_train<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
+    }));
+}
+
+
+// sigmas: [M]
+// rgbs: [M, 3]
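+// Per-sample compositing follows the standard volume rendering quadrature:
+//   alpha_i = 1 - exp(-sigma_i * delta_i),  w_i = alpha_i * T_i,
+//   T_i = prod_{j<i} (1 - alpha_j),
+// accumulated front to back until transmittance T drops below T_thresh.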
+// deltas: [M, 2]
+// rays: [N, 3], idx, offset, num_steps
+// weights_sum: [N], final pixel alpha
+// depth: [N,]
+// image: [N, 3]
+template <typename scalar_t>
+__global__ void kernel_composite_rays_train_forward(
+    const scalar_t * __restrict__ sigmas,
+    const scalar_t * __restrict__ rgbs,
+    const scalar_t * __restrict__ deltas,
+    const int * __restrict__ rays,
+    const uint32_t M, const uint32_t N, const float T_thresh,
+    scalar_t * weights_sum,
+    scalar_t * depth,
+    scalar_t * image
+) {
+    // parallel per ray
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    uint32_t index = rays[n * 3];
+    uint32_t offset = rays[n * 3 + 1];
+    uint32_t num_steps = rays[n * 3 + 2];
+
+    // empty ray, or ray that exceeds the max step count.
+    if (num_steps == 0 || offset + num_steps > M) {
+        weights_sum[index] = 0;
+        depth[index] = 0;
+        image[index * 3] = 0;
+        image[index * 3 + 1] = 0;
+        image[index * 3 + 2] = 0;
+        return;
+    }
+
+    sigmas += offset;
+    rgbs += offset * 3;
+    deltas += offset * 2;
+
+    // accumulate
+    uint32_t step = 0;
+
+    scalar_t T = 1.0f;
+    scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;
+
+    while (step < num_steps) {
+
+        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+        const scalar_t weight = alpha * T;
+
+        r += weight * rgbs[0];
+        g += weight * rgbs[1];
+        b += weight * rgbs[2];
+
+        t += deltas[1]; // real delta
+        d += weight * t;
+
+        ws += weight;
+
+        T *= 1.0f - alpha;
+
+        // minimal remaining transmittance
+        if (T < T_thresh) break;
+
+        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+        // locate
+        sigmas++;
+        rgbs += 3;
+        deltas += 2;
+
+        step++;
+    }
+
+    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+    // write
+    weights_sum[index] = ws; // weights_sum
+    depth[index] = d;
+    image[index * 3] = r;
+    image[index * 3 + 1] = g;
+    image[index * 3 + 2] = b;
+}
+
+
+void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
+
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
+        kernel_composite_rays_train_forward<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
+    }));
+}
+
+
+// grad_weights_sum: [N,]
+// grad: [N, 3]
+// sigmas: [M]
+// rgbs: [M, 3]
+// deltas: [M, 2]
+// rays: [N, 3], idx, offset, num_steps
+// weights_sum: [N,], weights_sum here
+// image: [N, 3]
+// grad_sigmas: [M]
+// grad_rgbs: [M, 3]
+template <typename scalar_t>
+__global__ void kernel_composite_rays_train_backward(
+    const scalar_t * __restrict__ grad_weights_sum,
+    const scalar_t * __restrict__ grad_image,
+    const scalar_t * __restrict__ sigmas,
+    const scalar_t * __restrict__ rgbs,
+    const scalar_t * __restrict__ deltas,
+    const int * __restrict__ rays,
+    const scalar_t * __restrict__ weights_sum,
+    const scalar_t * __restrict__ image,
+    const uint32_t M, const uint32_t N, const float T_thresh,
+    scalar_t * grad_sigmas,
+    scalar_t * grad_rgbs
+) {
+    // parallel per ray
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= N) return;
+
+    // locate
+    uint32_t index = rays[n * 3];
+    uint32_t offset = rays[n * 3 + 1];
+    uint32_t num_steps = rays[n * 3 + 2];
+
+    if (num_steps == 0 || offset + num_steps > M) return;
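+    // Rather than caching per-sample weights from the forward pass, the
+    // backward pass replays the same front-to-back accumulation, trading a
+    // little compute for O(1) extra memory per sample.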
+
+    grad_weights_sum += index;
+    grad_image += index * 3;
+    weights_sum += index;
+    image += index * 3;
+    sigmas += offset;
+    rgbs += offset * 3;
+    deltas += offset * 2;
+    grad_sigmas += offset;
+    grad_rgbs += offset * 3;
+
+    // accumulate
+    uint32_t step = 0;
+
+    scalar_t T = 1.0f;
+    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
+    scalar_t r = 0, g = 0, b = 0, ws = 0;
+
+    while (step < num_steps) {
+
+        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+        const scalar_t weight = alpha * T;
+
+        r += weight * rgbs[0];
+        g += weight * rgbs[1];
+        b += weight * rgbs[2];
+        ws += weight;
+
+        T *= 1.0f - alpha;
+
+        // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
+        // write grad_rgbs
+        grad_rgbs[0] = grad_image[0] * weight;
+        grad_rgbs[1] = grad_image[1] * weight;
+        grad_rgbs[2] = grad_image[2] * weight;
+
+        // write grad_sigmas
+        grad_sigmas[0] = deltas[0] * (
+            grad_image[0] * (T * rgbs[0] - (r_final - r)) +
+            grad_image[1] * (T * rgbs[1] - (g_final - g)) +
+            grad_image[2] * (T * rgbs[2] - (b_final - b)) +
+            grad_weights_sum[0] * (1 - ws_final)
+        );
+
+        //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
+        // minimal remaining transmittance
+        if (T < T_thresh) break;
+
+        // locate
+        sigmas++;
+        rgbs += 3;
+        deltas += 2;
+        grad_sigmas++;
+        grad_rgbs += 3;
+
+        step++;
+    }
+}
+
+
+void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
+
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
+        kernel_composite_rays_train_backward<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, T_thresh, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
+    }));
+}
+
+
+////////////////////////////////////////////////////
+/////////////        inference         /////////////
+////////////////////////////////////////////////////
+
+template <typename scalar_t>
+__global__ void kernel_march_rays(
+    const uint32_t n_alive,
+    const uint32_t n_step,
+    const int* __restrict__ rays_alive,
+    const scalar_t* __restrict__ rays_t,
+    const scalar_t* __restrict__ rays_o,
+    const scalar_t* __restrict__ rays_d,
+    const float bound,
+    const float dt_gamma, const uint32_t max_steps,
+    const uint32_t C, const uint32_t H,
+    const uint8_t * __restrict__ grid,
+    const scalar_t* __restrict__ nears,
+    const scalar_t* __restrict__ fars,
+    scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
+    const scalar_t* __restrict__ noises
+) {
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= n_alive) return;
+
+    const int index = rays_alive[n]; // ray id
+    const float noise = noises[n];
+
+    // locate
+    rays_o += index * 3;
+    rays_d += index * 3;
+    xyzs += n * n_step * 3;
+    dirs += n * n_step * 3;
+    deltas += n * n_step * 2;
+
+    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+    const float rH = 1 / (float)H;
+    const float H3 = H * H * H;
+    float t = rays_t[index]; // current ray's t
+    const float near = nears[index], far = fars[index];
+
+    const float dt_min = 2 * SQRT3() / max_steps;
+    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;
+
+    // march for n_step steps, record points
+    uint32_t step = 0;
+
+    // introduce some randomness
+    t += clamp(t * dt_gamma, dt_min, dt_max) * noise;
+
+    float last_t = t;
+
+    while (t < far && step < n_step) {
+        // current point
+        const float x = clamp(ox + t * dx, -bound, bound);
+        const float y = clamp(oy + t * dy, -bound, bound);
+        const float z = clamp(oz + t * dz, -bound, bound);
+
+        const float dt = clamp(t * dt_gamma, dt_min, dt_max);
+
+        // get mip level
+        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]
+
+        const float mip_bound = fminf(scalbnf(1, level), bound);
+        const float mip_rbound = 1 / mip_bound;
+
+        // convert to nearest grid position
+        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
+
+        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
+        const bool occ = grid[index / 8] & (1 << (index % 8));
+
+        // if occupied, advance a small step, and write to output
+        if (occ) {
+            // write step
+            xyzs[0] = x;
+            xyzs[1] = y;
+            xyzs[2] = z;
+            dirs[0] = dx;
+            dirs[1] = dy;
+            dirs[2] = dz;
+            // calc dt
+            t += dt;
+            deltas[0] = dt;
+            deltas[1] = t - last_t; // used to calc depth
+            last_t = t;
+            // step
+            xyzs += 3;
+            dirs += 3;
+            deltas += 2;
+            step++;
+
+        // else, skip a large step (basically skip a voxel grid)
+        } else {
+            // calc distance to next voxel
+            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
+            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
+            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
+            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
+            // step until next voxel
+            do {
+                t += clamp(t * dt_gamma, dt_min, dt_max);
+            } while (t < tt);
+        }
+    }
+}
+
+
+void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) {
+    static constexpr uint32_t N_THREAD = 128;
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    rays_o.scalar_type(), "march_rays", ([&] {
+        kernel_march_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), noises.data_ptr<scalar_t>());
+    }));
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_composite_rays(
+    const uint32_t n_alive,
+    const uint32_t n_step,
+    const float T_thresh,
+    int* rays_alive,
+    scalar_t* rays_t,
+    const scalar_t* __restrict__ sigmas,
+    const scalar_t* __restrict__ rgbs,
+    const scalar_t* __restrict__ deltas,
+    scalar_t* weights_sum, scalar_t* depth, scalar_t* image
+) {
+    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
+    if (n >= n_alive) return;
+
+    const int index = rays_alive[n]; // ray id
+
+    // locate
+    sigmas += n * n_step;
+    rgbs += n * n_step * 3;
+    deltas += n * n_step * 2;
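+    // Inputs (sigmas/rgbs/deltas) are packed densely by alive-ray slot n,
+    // while the accumulators (rays_t, weights_sum, depth, image) live at the
+    // ray's original index, so results land in the full-resolution buffers.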
+
+    rays_t += index;
+    weights_sum += index;
+    depth += index;
+    image += index * 3;
+
+    scalar_t t = rays_t[0]; // current ray's t
+
+    scalar_t weight_sum = weights_sum[0];
+    scalar_t d = depth[0];
+    scalar_t r = image[0];
+    scalar_t g = image[1];
+    scalar_t b = image[2];
+
+    // accumulate
+    uint32_t step = 0;
+    while (step < n_step) {
+
+        // ray is terminated if delta == 0
+        if (deltas[0] == 0) break;
+
+        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
+
+        /*
+        T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
+        w_i = alpha_i * T_i
+        -->
+        T_i = 1 - \sum_{j=0}^{i-1} w_j
+        */
+        const scalar_t T = 1 - weight_sum;
+        const scalar_t weight = alpha * T;
+        weight_sum += weight;
+
+        t += deltas[1]; // real delta
+        d += weight * t;
+        r += weight * rgbs[0];
+        g += weight * rgbs[1];
+        b += weight * rgbs[2];
+
+        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);
+
+        // ray is terminated if T is too small
+        // use a larger bound to further accelerate inference
+        if (T < T_thresh) break;
+
+        // locate
+        sigmas++;
+        rgbs += 3;
+        deltas += 2;
+        step++;
+    }
+
+    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);
+
+    // rays_alive = -1 means ray is terminated early.
+    if (step < n_step) {
+        rays_alive[n] = -1;
+    } else {
+        rays_t[0] = t;
+    }
+
+    weights_sum[0] = weight_sum; // write back the accumulated alpha
+    depth[0] = d;
+    image[0] = r;
+    image[1] = g;
+    image[2] = b;
+}
+
+
+void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
+    static constexpr uint32_t N_THREAD = 128;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+    image.scalar_type(), "composite_rays", ([&] {
+        kernel_composite_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
+    }));
+}
\ No newline at end of file
diff --git a/raymarching/src/raymarching.h b/raymarching/src/raymarching.h
new file mode 100755
index 0000000..3a2e692
--- /dev/null
+++ b/raymarching/src/raymarching.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <stdint.h>
+#include <torch/torch.h>
+
+
+void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
+void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
+void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
+void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
+void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
+
+void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises);
+void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
+void
composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs); + +void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); +void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000..598ceec --- /dev/null +++ b/requirements.txt @@ -0,0 +1,56 @@ +tqdm +rich +ninja +numpy +pandas +scipy +scikit-learn +matplotlib +opencv-python +imageio +imageio-ffmpeg + +torch +torch-ema +einops +tensorboard +tensorboardX + +# for gui +dearpygui + +# for grid_tcnn +# git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + +# for stable-diffusion +huggingface_hub +diffusers >= 0.9.0 +accelerate +transformers + +# for dmtet and mesh export +xatlas +trimesh +PyMCubes +pymeshlab +git+https://github.com/NVlabs/nvdiffrast/ + +# for zero123 +carvekit-colab +omegaconf +pytorch-lightning +taming-transformers-rom1504 +kornia +git+https://github.com/openai/CLIP.git + +# for omnidata +gdown + +# for dpt +timm + +# for remote debugging +debugpy-run + +# for deepfloyd if +sentencepiece \ No newline at end of file diff --git a/scripts/colmap2nerf.py b/scripts/colmap2nerf.py new file mode 100755 index 0000000..b4f463e --- /dev/null +++ b/scripts/colmap2nerf.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import argparse +import os +from pathlib import Path, PurePosixPath + +import numpy as np +import json +import sys +import math +import cv2 +import os +import shutil + +def parse_args(): + parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") + + parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also") + parser.add_argument("--video_fps", default=2) + parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. 
diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000..598ceec --- /dev/null +++ b/requirements.txt @@ -0,0 +1,56 @@ +tqdm +rich +ninja +numpy +pandas +scipy +scikit-learn +matplotlib +opencv-python +imageio +imageio-ffmpeg + +torch +torch-ema +einops +tensorboard +tensorboardX + +# for gui +dearpygui + +# for grid_tcnn +# git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + +# for stable-diffusion +huggingface_hub +diffusers >= 0.9.0 +accelerate +transformers + +# for dmtet and mesh export +xatlas +trimesh +PyMCubes +pymeshlab +git+https://github.com/NVlabs/nvdiffrast/ + +# for zero123 +carvekit-colab +omegaconf +pytorch-lightning +taming-transformers-rom1504 +kornia +git+https://github.com/openai/CLIP.git + +# for omnidata +gdown + +# for dpt +timm + +# for remote debugging +debugpy-run + +# for deepfloyd if +sentencepiece \ No newline at end of file diff --git a/scripts/colmap2nerf.py b/scripts/colmap2nerf.py new file mode 100755 index 0000000..b4f463e --- /dev/null +++ b/scripts/colmap2nerf.py @@ -0,0 +1,331 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import argparse +import os +from pathlib import Path, PurePosixPath + +import numpy as np +import json +import sys +import math +import cv2 +import os +import shutil + +def parse_args(): + parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") + + parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also") + parser.add_argument("--video_fps", default=2) + parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from the 10th to the 300th second of the video") + parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") + parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for ad-hoc images") + parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") + parser.add_argument("--colmap_camera_model", default="OPENCV", choices=["SIMPLE_PINHOLE", "PINHOLE", "SIMPLE_RADIAL", "RADIAL","OPENCV"], help="camera model") + parser.add_argument("--colmap_camera_params", default="", help="intrinsic parameters, depending on the chosen model. Format: fx,fy,cx,cy,dist") + parser.add_argument("--images", default="images", help="input path to the images") + parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") + parser.add_argument("--aabb_scale", default="16", choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16") + parser.add_argument("--skip_early", default=0, help="skip this many images from the start") + parser.add_argument("--keep_colmap_coords", action="store_true", help="keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering)") + parser.add_argument("--out", default="transforms.json", help="output path") + parser.add_argument("--vocab_path", default="", help="vocabulary tree path") + args = parser.parse_args() + return args + +def do_system(arg): + print(f"==== running: {arg}") + err = os.system(arg) + if err: + print("FATAL: command failed") + sys.exit(err) + +def run_ffmpeg(args): + if not os.path.isabs(args.images): + args.images = os.path.join(os.path.dirname(args.video_in), args.images) + images = "\"" + args.images + "\"" + video = "\"" + args.video_in + "\"" + fps = float(args.video_fps) or 1.0 + print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") + if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": + sys.exit(1) + try: + # rmtree needs the raw (unquoted) path + shutil.rmtree(args.images) + except OSError: + pass + do_system(f"mkdir {images}") + + time_slice_value = "" + time_slice = args.time_slice + if time_slice: + start, end = time_slice.split(",") + time_slice_value = f",select='between(t\\,{start}\\,{end})'" + do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")
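The doubled backslash in `time_slice_value` matters: commas separate filters inside `-vf`, so the commas of `between()` must reach ffmpeg escaped as `\,`. A quick sketch of the command the function assembles, with `in.mp4` and `images` as invented stand-ins for the actual arguments:

```python
# Reconstructing run_ffmpeg's command string for hypothetical inputs
# --video_in in.mp4 --video_fps 2 --time_slice 10,300.
fps = 2.0
start, end = "10", "300"
time_slice_value = f",select='between(t\\,{start}\\,{end})'"
cmd = f"ffmpeg -i \"in.mp4\" -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" \"images\"/%04d.jpg"
print(cmd)
# ffmpeg -i "in.mp4" -qscale:v 1 -qmin 1 -vf "fps=2.0,select='between(t\,10\,300)'" "images"/%04d.jpg
```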
def run_colmap(args): + db = args.colmap_db + images = "\"" + args.images + "\"" + db_noext=str(Path(db).with_suffix("")) + + if args.text=="text": + args.text=db_noext+"_text" + text=args.text + sparse=db_noext+"_sparse" + print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") + if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": + sys.exit(1) + if os.path.exists(db): + os.remove(db) + do_system(f"colmap feature_extractor --ImageReader.camera_model {args.colmap_camera_model} --ImageReader.camera_params \"{args.colmap_camera_params}\" --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}") + match_cmd = f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}" + if args.vocab_path: + match_cmd += f" --VocabTreeMatching.vocab_tree_path {args.vocab_path}" + do_system(match_cmd) + try: + shutil.rmtree(sparse) + except OSError: + pass + do_system(f"mkdir {sparse}") + do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") + do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") + try: + shutil.rmtree(text) + except OSError: + pass + do_system(f"mkdir {text}") + do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") + +def variance_of_laplacian(image): + return cv2.Laplacian(image, cv2.CV_64F).var() + +def sharpness(imagePath): + image = cv2.imread(imagePath) + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + fm = variance_of_laplacian(gray) + return fm + +def qvec2rotmat(qvec): + return np.array([ + [ + 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] + ], [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] + ], [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 + ] + ]) + +def rotmat(a, b): + a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) + v = np.cross(a, b) + c = np.dot(a, b) + # handle exception for the opposite direction input + if c < -1 + 1e-10: + return rotmat(a + np.random.uniform(-1e-2, 1e-2, 3), b) + s = np.linalg.norm(v) + kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) + return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) + +def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel + da = da / np.linalg.norm(da) + db = db / np.linalg.norm(db) + c = np.cross(da, db) + denom = np.linalg.norm(c)**2 + t = ob - oa + ta = np.linalg.det([t, db, c]) / (denom + 1e-10) + tb = np.linalg.det([t, da, c]) / (denom + 1e-10) + if ta > 0: + ta = 0 + if tb > 0: + tb = 0 + return (oa+ta*da+ob+tb*db) * 0.5, denom
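The quaternion convention here is COLMAP's (qw, qx, qy, qz), which is easy to get wrong. A small standalone check against scipy (already in requirements.txt), whose `from_quat` expects (x, y, z, w) order; `qvec2rotmat` is the function defined above, and the test quaternion is arbitrary:

```python
# Sanity check: qvec2rotmat agrees with scipy after reordering components.
import numpy as np
from scipy.spatial.transform import Rotation

qvec = np.array([0.7071068, 0.0, 0.7071068, 0.0])  # (w, x, y, z): 90 deg about +y
R_here = qvec2rotmat(qvec)
R_ref = Rotation.from_quat(qvec[[1, 2, 3, 0]]).as_matrix()  # scipy wants (x, y, z, w)
assert np.allclose(R_here, R_ref, atol=1e-6)
assert np.allclose(R_here @ R_here.T, np.eye(3), atol=1e-6)  # orthonormal
```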
if __name__ == "__main__": + args = parse_args() + if args.video_in != "": + run_ffmpeg(args) + if args.run_colmap: + run_colmap(args) + AABB_SCALE = int(args.aabb_scale) + SKIP_EARLY = int(args.skip_early) + IMAGE_FOLDER = args.images + TEXT_FOLDER = args.text + OUT_PATH = args.out + print(f"outputting to {OUT_PATH}...") + with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f: + angle_x = math.pi / 2 + for line in f: + # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 + # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 + # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 + if line[0] == "#": + continue + els = line.split(" ") + w = float(els[2]) + h = float(els[3]) + fl_x = float(els[4]) + fl_y = float(els[4]) + k1 = 0 + k2 = 0 + p1 = 0 + p2 = 0 + cx = w / 2 + cy = h / 2 + if els[1] == "SIMPLE_PINHOLE": + cx = float(els[5]) + cy = float(els[6]) + elif els[1] == "PINHOLE": + fl_y = float(els[5]) + cx = float(els[6]) + cy = float(els[7]) + elif els[1] == "SIMPLE_RADIAL": + cx = float(els[5]) + cy = float(els[6]) + k1 = float(els[7]) + elif els[1] == "RADIAL": + cx = float(els[5]) + cy = float(els[6]) + k1 = float(els[7]) + k2 = float(els[8]) + elif els[1] == "OPENCV": + fl_y = float(els[5]) + cx = float(els[6]) + cy = float(els[7]) + k1 = float(els[8]) + k2 = float(els[9]) + p1 = float(els[10]) + p2 = float(els[11]) + else: + print("unknown camera model ", els[1]) + # fl = 0.5 * w / tan(0.5 * angle_x); + angle_x = math.atan(w / (fl_x * 2)) * 2 + angle_y = math.atan(h / (fl_y * 2)) * 2 + fovx = angle_x * 180 / math.pi + fovy = angle_y * 180 / math.pi + + print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") + + with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f: + i = 0 + bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) + out = { + "camera_angle_x": angle_x, + "camera_angle_y": angle_y, + "fl_x": fl_x, + "fl_y": fl_y, + "k1": k1, + "k2": k2, + "p1": p1, + "p2": p2, + "cx": cx, + "cy": cy, + "w": w, + "h": h, + "aabb_scale": AABB_SCALE, + "frames": [], + } + + up = np.zeros(3) + for line in f: + line = line.strip() + if len(line) == 0 or line[0] == "#": + continue + i = i + 1 + if i < SKIP_EARLY*2: + continue + if i % 2 == 1: + elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) + #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9]))) + # why is this requiring a relative path while using ^ + image_rel = os.path.relpath(IMAGE_FOLDER) + name = str(f"./{image_rel}/{'_'.join(elems[9:])}") + b=sharpness(name) + print(name, "sharpness=",b) + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + R = qvec2rotmat(-qvec) + t = tvec.reshape([3,1]) + m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) + c2w = np.linalg.inv(m) + if not args.keep_colmap_coords: + c2w[0:3,2] *= -1 # flip the y and z axis + c2w[0:3,1] *= -1 + c2w = c2w[[1,0,2,3],:] # swap y and z + c2w[2,:] *= -1 # flip whole world upside down + + up += c2w[0:3,1] + + frame={"file_path":name,"sharpness":b,"transform_matrix": c2w} + out["frames"].append(frame) + nframes = len(out["frames"]) + + if args.keep_colmap_coords: + flip_mat = np.array([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1] + ]) + + for f in out["frames"]: + f["transform_matrix"] = np.matmul(f["transform_matrix"], flip_mat) # flip cameras (it just works) + else: + # don't keep colmap coords - reorient the scene to be easier to work with + + up = up / np.linalg.norm(up) + print("up vector was", up) + R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1] + R = np.pad(R,[0,1]) + R[-1, -1] = 1 + + for f in out["frames"]: + f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis + + # find a central point they are all looking at + print("computing center of attention...") + totw = 0.0 + totp = np.array([0.0, 0.0, 0.0]) + for f in out["frames"]: + mf = f["transform_matrix"][0:3,:] + for g in out["frames"]: + mg = g["transform_matrix"][0:3,:] + p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) + if w > 0.00001: + totp += p*w + totw += w + if totw > 0.0: + totp /= totw + print(totp) # the cameras are looking at totp + for f in out["frames"]: +
f["transform_matrix"][0:3,3] -= totp + + avglen = 0. + for f in out["frames"]: + avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) + avglen /= nframes + print("avg camera distance from origin", avglen) + for f in out["frames"]: + f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" + + for f in out["frames"]: + f["transform_matrix"] = f["transform_matrix"].tolist() + print(nframes,"frames") + print(f"writing {OUT_PATH}") + with open(OUT_PATH, "w") as outfile: + json.dump(out, outfile, indent=2) diff --git a/scripts/install_ext.sh b/scripts/install_ext.sh new file mode 100755 index 0000000..228190e --- /dev/null +++ b/scripts/install_ext.sh @@ -0,0 +1,4 @@ +pip install ./raymarching +pip install ./shencoder +pip install ./freqencoder +pip install ./gridencoder \ No newline at end of file diff --git a/scripts/run.sh b/scripts/run.sh new file mode 100755 index 0000000..442196c --- /dev/null +++ b/scripts/run.sh @@ -0,0 +1,5 @@ +#! /bin/bash + +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of cthulhu" --workspace trial_cthulhu +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a squirrel" --workspace trial_squirrel +CUDA_VISIBLE_DEVICES=1 python main.py -O --text "a DSLR photo of a cat lying on its side batting at a ball of yarn" --workspace trial_cat_lying \ No newline at end of file diff --git a/shencoder/__init__.py b/shencoder/__init__.py new file mode 100755 index 0000000..2b55c96 --- /dev/null +++ b/shencoder/__init__.py @@ -0,0 +1 @@ +from .sphere_harmonics import SHEncoder \ No newline at end of file diff --git a/shencoder/backend.py b/shencoder/backend.py new file mode 100755 index 0000000..cc08a3e --- /dev/null +++ b/shencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_sh_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/shencoder/setup.py b/shencoder/setup.py new file mode 100755 index 0000000..342a601 --- /dev/null +++ b/shencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='shencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_shencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'shencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/shencoder/sphere_harmonics.py b/shencoder/sphere_harmonics.py new file mode 100755 index 0000000..7bab24e --- /dev/null +++ b/shencoder/sphere_harmonics.py @@ -0,0 +1,87 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _shencoder as _backend +except ImportError: + from .backend import _backend + +class _sh_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, calc_grad_inputs=False): + # inputs: [B, input_dim], float in [-1, 1] + # RETURN: [B, F], float + + inputs = inputs.contiguous() + B, input_dim = inputs.shape # batch size, coord dim + output_dim = degree ** 2 + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + if calc_grad_inputs: + dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) + else: + dy_dx = None + + _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) + + ctx.save_for_backward(inputs, dy_dx) + ctx.dims = [B, input_dim, degree] + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + inputs, dy_dx = ctx.saved_tensors + + if dy_dx is not None: + grad = grad.contiguous() + B, input_dim, degree = ctx.dims + 
grad_inputs = torch.zeros_like(inputs) + _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) + return grad_inputs, None, None + else: + return None, None, None + + + +sh_encode = _sh_encoder.apply + + +class SHEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim # coord dims, must be 3 + self.degree = degree # 1 ~ 8 + self.output_dim = degree ** 2 + + assert self.input_dim == 3, "SH encoder only supports input dim == 3" + assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" + + def __repr__(self): + return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" + + def forward(self, inputs, size=1): + # inputs: [..., input_dim], normalized real world positions in [-size, size] + # return: [..., degree^2] + + inputs = inputs / size # [-1, 1] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = sh_encode(inputs, self.degree, inputs.requires_grad) + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs \ No newline at end of file diff --git a/shencoder/src/bindings.cpp b/shencoder/src/bindings.cpp new file mode 100755 index 0000000..595b5b3 --- /dev/null +++ b/shencoder/src/bindings.cpp @@ -0,0 +1,8 @@ +#include <torch/extension.h> + +#include "shencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); + m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); +} \ No newline at end of file diff --git a/shencoder/src/shencoder.cu b/shencoder/src/shencoder.cu new file mode 100755 index 0000000..a92e4ab --- /dev/null +++ b/shencoder/src/shencoder.cu @@ -0,0 +1,439 @@ +#include <stdint.h> + +#include <cuda.h> +#include <cuda_fp16.h> +#include <cuda_runtime.h> + +#include <ATen/cuda/CUDAContext.h> +#include <torch/torch.h> + +#include <algorithm> +#include <stdexcept> + +#include <cstdio> + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +template <typename T> +__host__ __device__ T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template <typename scalar_t> +__global__ void kernel_sh( + const scalar_t * __restrict__ inputs, + scalar_t * outputs, + uint32_t B, uint32_t D, uint32_t C, + scalar_t * dy_dx +) { + const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; + if (b >= B) return; + + const uint32_t C2 = C * C; + + // locate + inputs += b * D; + outputs += b * C2; + + scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; + + scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z; + scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; + scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; + + auto write_sh = [&]() { + outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) + if (C <= 1) { return; } + outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) + outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) + outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) + if (C <= 2) { return; } + outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) + outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) + outputs[6] = 0.94617469575755997f*z2 -
0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) + outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) + outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) + if (C <= 3) { return; } + outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) + outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) + outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) + outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) + outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) + outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) + outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) + outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) + outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) + outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) + outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) + outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) + outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + if (C <= 5) { return; } + outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) + outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) + outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) + outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + outputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) + outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) + outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + outputs[38] = 2.0182596029148963f*xy*(x2 - 
y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) + outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) + outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) + outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + if (C <= 7) { return; } + outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) + outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) + outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) + outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) + outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) + outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 
15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) + }; + + write_sh(); + + if (dy_dx) { + scalar_t *dx = dy_dx + b * D * C2; + scalar_t *dy = dx + C2; + scalar_t *dz = dy + C2; + + auto write_sh_dx = [&]() { + dx[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dx[1] = 0.0f ; // 0 + dx[2] = 0.0f ; // 0 + dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + if (C <= 2) { return; } + dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) + dx[5] = 0.0f ; // 0 + dx[6] = 0.0f ; // 0 + dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + if (C <= 3) { return; } + dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) + dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) + dx[11] = 0.0f ; // 0 + dx[12] = 0.0f ; // 0 + dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + if (C <= 4) { return; } + dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) + dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) + dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) + dx[19] = 0.0f ; // 0 + dx[20] = 0.0f ; // 0 + dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) + dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) + dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) + dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) + dx[29] = 0.0f ; // 0 + dx[30] = 0.0f ; // 0 + dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) + dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + if (C <= 6) { return; } + dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) + dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) + dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 
33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[41] = 0.0f ; // 0 + dx[42] = 0.0f ; // 0 + dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) + dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) + dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) + dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[55] = 0.0f ; // 0 + dx[56] = 0.0f ; // 0 + dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) + dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) + dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + }; + + auto write_sh_dy = [&]() { + dy[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) + dy[2] = 0.0f ; // 0 + dy[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) + dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) + dy[6] = 0.0f ; // 0 + dy[7] = 0.0f ; // 0 + dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + if (C <= 3) { return; } + dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) + dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) + dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) + dy[12] = 0.0f ; 
// 0 + dy[13] = 0.0f ; // 0 + dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) + dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) + if (C <= 4) { return; } + dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) + dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) + dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) + dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) + dy[20] = 0.0f ; // 0 + dy[21] = 0.0f ; // 0 + dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) + dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) + dy[24] = 2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) + if (C <= 5) { return; } + dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) + dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) + dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) + dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) + dy[30] = 0.0f ; // 0 + dy[31] = 0.0f ; // 0 + dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) + dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) + dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) + dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) + if (C <= 6) { return; } + dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) + dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) + dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) + dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) + dy[42] = 0.0f ; // 0 + dy[43] = 0.0f ; // 0 + dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) + dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) + dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) + dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) + dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + if (C <= 7) { return; } + dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) + dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) + dy[51] = 
0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) + dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) + dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) + dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) + dy[56] = 0.0f ; // 0 + dy[57] = 0.0f ; // 0 + dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) + dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) + dy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + }; + + auto write_sh_dz = [&]() { + dz[0] = 0.0f ; // 0 + if (C <= 1) { return; } + dz[1] = 0.0f ; // 0 + dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) + dz[3] = 0.0f ; // 0 + if (C <= 2) { return; } + dz[4] = 0.0f ; // 0 + dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) + dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) + dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) + dz[8] = 0.0f ; // 0 + if (C <= 3) { return; } + dz[9] = 0.0f ; // 0 + dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) + dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi)) + dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) + dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) + dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) + dz[15] = 0.0f ; // 0 + if (C <= 4) { return; } + dz[16] = 0.0f ; // 0 + dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) + dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) + dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) + dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) + dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) + dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) + dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) + dz[24] = 0.0f ; // 0 + if (C <= 5) { return; } + dz[25] = 0.0f ; // 0 + dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) + dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) + dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) + dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 
3*z2)/(4*sqrt(pi)) + dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) + dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) + dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) + dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) + dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[35] = 0.0f ; // 0 + if (C <= 6) { return; } + dz[36] = 0.0f ; // 0 + dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) + dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) + dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) + dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) + dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) + dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) + dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) + dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[48] = 0.0f ; // 0 + if (C <= 7) { return; } + dz[49] = 0.0f ; // 0 + dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) + dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) + dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi)) + dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) + dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi)) + dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) + dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi)) + dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi)) + dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) + dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 
5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) + dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) + dz[63] = 0.0f ; // 0 + }; + write_sh_dx(); + write_sh_dy(); + write_sh_dz(); + } +} + + +template <typename scalar_t> +__global__ void kernel_sh_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ inputs, + uint32_t B, uint32_t D, uint32_t C, + const scalar_t * __restrict__ dy_dx, + scalar_t * grad_inputs +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + const uint32_t b = t / D; + if (b >= B) return; + + const uint32_t d = t - b * D; + const uint32_t C2 = C * C; + + // locate + grad += b * C2; + dy_dx += b * D * C2 + d * C2; + + for (int ch = 0; ch < C2; ch++) { + grad_inputs[t] += grad[ch] * dy_dx[ch]; + //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]); + } + +} + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, C * C], float +template <typename scalar_t> +void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, dy_dx); +} + + +template <typename scalar_t> +void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) { + static constexpr uint32_t N_THREADS = 256; + kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs); +} + + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx) { + CHECK_CUDA(inputs); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + inputs.scalar_type(), "sh_encode_forward_cuda", ([&] { + sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, dy_dx.has_value() ?
dy_dx.value().data_ptr<scalar_t>() : nullptr); + })); +} + +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(dy_dx); + CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(dy_dx); + CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(dy_dx); + CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "sh_encode_backward_cuda", ([&] { + sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>()); + })); +} \ No newline at end of file diff --git a/shencoder/src/shencoder.h b/shencoder/src/shencoder.h new file mode 100755 index 0000000..f9e89fa --- /dev/null +++ b/shencoder/src/shencoder.h @@ -0,0 +1,10 @@ +#pragma once + +#include <stdint.h> +#include <torch/torch.h> + +// inputs: [B, D], float, in [-1, 1] +// outputs: [B, F], float + +void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, at::optional<at::Tensor> dy_dx); +void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs); \ No newline at end of file
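Once the extension is built (scripts/install_ext.sh above), the spherical-harmonics encoder can be exercised end to end; a quick smoke-test sketch, assuming a CUDA device is available, with batch size and degree chosen arbitrarily:

```python
# Minimal usage sketch for SHEncoder from this patch; requires the
# compiled _shencoder extension and a CUDA device.
import torch
from shencoder import SHEncoder

enc = SHEncoder(input_dim=3, degree=4)
dirs = torch.randn(1024, 3, device='cuda')
dirs = dirs / dirs.norm(dim=-1, keepdim=True)  # unit directions, in [-1, 1]

feats = enc(dirs)          # [1024, 16]: degree ** 2 coefficients per input
print(feats.shape)

# With gradients enabled, the forward pass also fills dy_dx so that the
# analytic backward kernel can run:
dirs.requires_grad_(True)
enc(dirs).sum().backward()
print(dirs.grad.shape)     # [1024, 3]
```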