Skip to content

Commit

Permalink
add files
Browse files Browse the repository at this point in the history
  • Loading branch information
StarsTesla committed Aug 24, 2023
1 parent f8ddc2c commit 7c98d9a
Show file tree
Hide file tree
Showing 47 changed files with 7,866 additions and 0 deletions.
Empty file modified .gitignore
100644 → 100755
Empty file.
Empty file modified LICENSE
100644 → 100755
Empty file.
Empty file modified README.md
100644 → 100755
Empty file.
18 changes: 18 additions & 0 deletions activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

class _trunc_exp(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float)
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.exp(x)

@staticmethod
@custom_bwd
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(x.clamp(max=15))

trunc_exp = _trunc_exp.apply
75 changes: 75 additions & 0 deletions encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class FreqEncoder_torch(nn.Module):
def __init__(self, input_dim, max_freq_log2, N_freqs,
log_sampling=True, include_input=True,
periodic_fns=(torch.sin, torch.cos)):

super().__init__()

self.input_dim = input_dim
self.include_input = include_input
self.periodic_fns = periodic_fns

self.output_dim = 0
if self.include_input:
self.output_dim += self.input_dim

self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

if log_sampling:
self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs)
else:
self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs)

self.freq_bands = self.freq_bands.numpy().tolist()

def forward(self, input, **kwargs):

out = []
if self.include_input:
out.append(input)

for i in range(len(self.freq_bands)):
freq = self.freq_bands[i]
for p_fn in self.periodic_fns:
out.append(p_fn(input * freq))

out = torch.cat(out, dim=-1)

return out

def get_encoder(encoding, input_dim=3,
multires=6,
degree=4,
num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False,
**kwargs):

if encoding == 'None':
return lambda x, **kwargs: x, input_dim

elif encoding == 'frequency_torch':
encoder = FreqEncoder_torch(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)

elif encoding == 'frequency': # CUDA implementation, faster than torch.
from freqencoder import FreqEncoder
encoder = FreqEncoder(input_dim=input_dim, degree=multires)

elif encoding == 'sphere_harmonics':
from shencoder import SHEncoder
encoder = SHEncoder(input_dim=input_dim, degree=degree)

elif encoding == 'hashgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='hash', align_corners=align_corners)

elif encoding == 'tiledgrid':
from gridencoder import GridEncoder
encoder = GridEncoder(input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype='tiled', align_corners=align_corners)

else:
raise NotImplementedError('Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]')

return encoder, encoder.output_dim
1 change: 1 addition & 0 deletions freqencoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .freq import FreqEncoder
41 changes: 41 additions & 0 deletions freqencoder/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]

if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']

# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]

# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path

_backend = load(name='_freqencoder',
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
)

__all__ = ['_backend']
77 changes: 77 additions & 0 deletions freqencoder/freq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
import _freqencoder as _backend
except ImportError:
from .backend import _backend


class _freq_encoder(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # force float32 for better precision
def forward(ctx, inputs, degree, output_dim):
# inputs: [B, input_dim], float
# RETURN: [B, F], float

if not inputs.is_cuda: inputs = inputs.cuda()
inputs = inputs.contiguous()

B, input_dim = inputs.shape # batch size, coord dim

outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

_backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

ctx.save_for_backward(inputs, outputs)
ctx.dims = [B, input_dim, degree, output_dim]

return outputs

@staticmethod
#@once_differentiable
@custom_bwd
def backward(ctx, grad):
# grad: [B, C * C]

grad = grad.contiguous()
inputs, outputs = ctx.saved_tensors
B, input_dim, degree, output_dim = ctx.dims

grad_inputs = torch.zeros_like(inputs)
_backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs)

return grad_inputs, None, None


freq_encode = _freq_encoder.apply


class FreqEncoder(nn.Module):
def __init__(self, input_dim=3, degree=4):
super().__init__()

self.input_dim = input_dim
self.degree = degree
self.output_dim = input_dim + input_dim * 2 * degree

def __repr__(self):
return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"

def forward(self, inputs, **kwargs):
# inputs: [..., input_dim]
# return: [..., ]

prefix_shape = list(inputs.shape[:-1])
inputs = inputs.reshape(-1, self.input_dim)

outputs = freq_encode(inputs, self.degree, self.output_dim)

outputs = outputs.reshape(prefix_shape + [self.output_dim])

return outputs
51 changes: 51 additions & 0 deletions freqencoder/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
'-O3', '-std=c++14',
'-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
'-use_fast_math'
]

if os.name == "posix":
c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
c_flags = ['/O2', '/std:c++17']

# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True)
if paths:
return paths[0]

# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
os.environ["PATH"] += ";" + cl_path

setup(
name='freqencoder', # package name, import this to use python API
ext_modules=[
CUDAExtension(
name='_freqencoder', # extension name, import this to use CUDA API
sources=[os.path.join(_src_path, 'src', f) for f in [
'freqencoder.cu',
'bindings.cpp',
]],
extra_compile_args={
'cxx': c_flags,
'nvcc': nvcc_flags,
}
),
],
cmdclass={
'build_ext': BuildExtension,
}
)
8 changes: 8 additions & 0 deletions freqencoder/src/bindings.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <torch/extension.h>

#include "freqencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("freq_encode_forward", &freq_encode_forward, "freq encode forward (CUDA)");
m.def("freq_encode_backward", &freq_encode_backward, "freq encode backward (CUDA)");
}
Loading

0 comments on commit 7c98d9a

Please sign in to comment.