"""Positional encodings and coarse-to-fine (annealed) variants for NeRF-style models."""
import math

import torch
import torch.nn as nn
class Embedder:
    """Maps x to (x, sin(2^k x), cos(2^k x), ...) over a set of frequency bands."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.multires = kwargs['num_freqs']
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            # Pass the raw input through unchanged.
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']
        if self.kwargs['log_sampling']:
            # Log-spaced frequencies: 2^0 ... 2^max_freq.
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            # Linearly spaced frequencies between 2^0 and 2^max_freq.
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # Bind p_fn and freq as default arguments so each lambda
                # captures its own values rather than the loop variables.
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        # Concatenate every encoding along the last (channel) dimension.
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, include_input=False, input_dims=2, max_freq_log2=10, log_sampling=True):
    """Builds an Embedder with `multires` sin/cos bands spanning
    2^0 ... 2^(max_freq_log2 - 1) and returns (embed_fn, out_dim)."""
    embed_kwargs = {
        'include_input': include_input,
        'input_dims': input_dims,
        'max_freq_log2': max_freq_log2 - 1,
        'num_freqs': multires,
        'log_sampling': log_sampling,
        'periodic_fns': [torch.sin, torch.cos],
    }
    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
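
# Minimal usage sketch for get_embedder (an illustrative addition; the
# _example_* helpers are hypothetical, not part of the original module):
# encode a batch of 2-D coordinates with 4 frequency bands.
def _example_get_embedder():
    embed_fn, out_dim = get_embedder(multires=4, include_input=True, input_dims=2)
    xy = torch.rand(8, 2)      # (B, input_dims) coordinates
    enc = embed_fn(xy)         # (B, out_dim) = (8, 2 * (1 + 2 * 4)) = (8, 18)
    assert enc.shape == (8, out_dim)
    return enc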
class AnnealedHash(nn.Module):
    def __init__(self, in_channels, annealed_step, annealed_begin_step=0):
        """
        Applies a coarse-to-fine annealing window to a multi-resolution hash
        embedding with N_freqs = 16 levels (2 features per level).
        in_channels: number of raw input channels (e.g. 3 for xyz); stored for
            reference, the forward pass operates on the embedded features.
        annealed_step: number of steps over which the window opens fully.
        annealed_begin_step: step at which annealing starts.
        """
        super(AnnealedHash, self).__init__()
        self.N_freqs = 16
        self.in_channels = in_channels
        self.annealed = True
        self.annealed_step = annealed_step
        self.annealed_begin_step = annealed_begin_step
        # One index per level, repeated twice (2 features per level).
        self.index = torch.linspace(0, self.N_freqs - 1, self.N_freqs)
        self.index_2 = self.index.view(-1, 1).repeat(1, 2).view(-1)

    def forward(self, x_embed, step):
        """
        Weights each level of the hash embedding with a smooth coarse-to-fine
        window. See https://github.com/bmild/nerf/issues/12
        Inputs:
            x_embed: (B, N_freqs * 2) hash-grid features
            step: current training step
        Outputs:
            out: (B, N_freqs * 2) annealed features
        """
        if self.annealed_begin_step == 0:
            # alpha ramps linearly from 0 to N_freqs over annealed_step steps.
            alpha = self.N_freqs * step / float(self.annealed_step)
        else:
            if step <= self.annealed_begin_step:
                alpha = 0
            else:
                alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
                    self.annealed_step)

        # Cosine-eased window: w_j rises from 0 to 1 as alpha passes level j.
        w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1))) / 2
        out = x_embed * w.to(x_embed.device)
        return out
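
# Minimal usage sketch for AnnealedHash (illustrative; _example_annealed_hash
# is a hypothetical helper, not part of the original module). At step 0 every
# level is masked to zero; once step >= annealed_step the window is fully open.
def _example_annealed_hash():
    ah = AnnealedHash(in_channels=3, annealed_step=1000)
    x_embed = torch.rand(8, 32)       # (B, N_freqs * 2) hash features
    early = ah(x_embed, step=0)       # all zeros: no level is open yet
    late = ah(x_embed, step=1000)     # equals x_embed: window fully open
    assert torch.all(early == 0) and torch.allclose(late, x_embed)
    return early, late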
class AnnealedEmbedding(nn.Module):
    def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, logscale=True):
        """
        Embeds x to (x, sin(2^k x), cos(2^k x), ...) with coarse-to-fine
        annealing weights on the frequency bands.
        in_channels: number of input channels (3 for both xyz and direction)
        annealed_step: number of steps over which the window opens fully.
        annealed_begin_step: step at which annealing starts.
        """
        super(AnnealedEmbedding, self).__init__()
        self.N_freqs = N_freqs
        self.in_channels = in_channels
        self.annealed = True
        self.annealed_step = annealed_step
        self.annealed_begin_step = annealed_begin_step
        self.funcs = [torch.sin, torch.cos]
        self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
        self.index = torch.linspace(0, N_freqs - 1, N_freqs)
        # Frequency bands 2^0 ... 2^(N_freqs - 1), honoring `logscale`.
        if logscale:
            self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
        else:
            self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)

    def forward(self, x, step):
        """
        Embeds x to (x, sin(2^k x), cos(2^k x), ...)
        Different from the paper, "x" is also in the output
        See https://github.com/bmild/nerf/issues/12
        Inputs:
            x: (B, self.in_channels)
            step: current training step
        Outputs:
            out: (B, self.out_channels)
        """
        if self.annealed_begin_step == 0:
            # alpha ramps linearly from 0 to N_freqs over annealed_step steps.
            alpha = self.N_freqs * step / float(self.annealed_step)
        else:
            if step <= self.annealed_begin_step:
                alpha = 0
            else:
                alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
                    self.annealed_step)

        out = [x]
        for j in range(self.N_freqs):
            # Cosine-eased window over band j: each sin/cos term is scaled
            # by its weight before concatenation.
            w = (1 - torch.cos(math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2
            for func in self.funcs:
                out += [w * func(self.freq_bands[j] * x)]
        return torch.cat(out, -1)
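
# Minimal usage sketch for AnnealedEmbedding (illustrative; these helpers are
# hypothetical additions, not part of the original module). At the halfway
# point of the schedule the low-frequency bands are fully open while the
# high-frequency bands are still masked.
def _example_annealed_embedding():
    emb = AnnealedEmbedding(in_channels=3, N_freqs=8, annealed_step=1000)
    xyz = torch.rand(8, 3)        # (B, in_channels)
    out = emb(xyz, step=500)      # (B, 3 * (2 * 8 + 1)) = (8, 51)
    assert out.shape == (8, emb.out_channels)
    return out


if __name__ == "__main__":
    # Smoke-test the illustrative sketches above.
    _example_get_embedder()
    _example_annealed_hash()
    _example_annealed_embedding()
    print("positional_encoding examples ran OK")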