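"""Coordinate encoding and an implicit MLP decoder (idecoder.py).

``CoordsEncoder`` applies a NeRF-style positional encoding to input
coordinates. ``ImplicitDecoder`` concatenates a per-sample latent embedding
with the encoded coordinates and maps them through an MLP with one skip
connection to a per-point output.
"""
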
from typing import Callable, Tuple

import torch
from einops import repeat
from torch import Tensor, nn


class CoordsEncoder:
    def __init__(
        self,
        input_dims: int = 3,
        include_input: bool = True,
        max_freq_log2: int = 9,
        num_freqs: int = 10,
        log_sampling: bool = True,
        periodic_fns: Tuple[Callable, Callable] = (torch.sin, torch.cos),
    ) -> None:
        self.input_dims = input_dims
        self.include_input = include_input
        self.max_freq_log2 = max_freq_log2
        self.num_freqs = num_freqs
        self.log_sampling = log_sampling
        self.periodic_fns = periodic_fns
        self.create_embedding_fn()

    def create_embedding_fn(self) -> None:
        embed_fns = []
        d = self.input_dims
        out_dim = 0
        if self.include_input:
            # Pass the raw coordinates through alongside the periodic encodings.
            embed_fns.append(lambda x: x)
            out_dim += d
        if self.log_sampling:
            # Frequencies spaced evenly in log2 space: 2^0, ..., 2^max_freq_log2.
            freq_bands = 2.0 ** torch.linspace(0.0, self.max_freq_log2, steps=self.num_freqs)
        else:
            # Frequencies spaced evenly in linear space over the same range.
            freq_bands = torch.linspace(2.0**0.0, 2.0**self.max_freq_log2, steps=self.num_freqs)
        for freq in freq_bands:
            for p_fn in self.periodic_fns:
                # Bind p_fn and freq as default arguments so each lambda
                # captures its own values rather than the loop variables.
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d
        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs: Tensor) -> Tensor:
        # Concatenate every encoding of the inputs along the last dimension.
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
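
# With the defaults (input_dims=3, include_input=True, num_freqs=10, and both
# sin and cos), the encoder maps each 3-D point to 3 + 3 * 2 * 10 = 63
# features: the raw coordinates plus sin and cos at 10 frequencies per axis.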


class ImplicitDecoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        in_dim: int,
        hidden_dim: int,
        num_hidden_layers_before_skip: int,
        num_hidden_layers_after_skip: int,
        out_dim: int,
    ) -> None:
        super().__init__()
        self.coords_enc = CoordsEncoder(in_dim)
        coords_dim = self.coords_enc.out_dim
        self.in_layer = nn.Sequential(nn.Linear(embed_dim + coords_dim, hidden_dim), nn.ReLU())
        # Projects the raw input to hidden_dim so it can be re-injected at the skip.
        self.skip_proj = nn.Sequential(nn.Linear(embed_dim + coords_dim, hidden_dim), nn.ReLU())
        before_skip = []
        for _ in range(num_hidden_layers_before_skip):
            before_skip.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()))
        self.before_skip = nn.Sequential(*before_skip)
        after_skip = []
        for _ in range(num_hidden_layers_after_skip):
            after_skip.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()))
        after_skip.append(nn.Linear(hidden_dim, out_dim))
        self.after_skip = nn.Sequential(*after_skip)

    def forward(self, embeddings: Tensor, coords: Tensor) -> Tensor:
        # embeddings: (B, D1) latent code per sample.
        # coords: (B, N, D2) query coordinates per sample.
        coords = self.coords_enc.embed(coords)
        # Broadcast each sample's embedding across its N query points.
        repeated_embeddings = repeat(embeddings, "b d -> b n d", n=coords.shape[1])
        emb_and_coords = torch.cat([repeated_embeddings, coords], dim=-1)
        x = self.in_layer(emb_and_coords)
        x = self.before_skip(x)
        # Skip connection: re-inject the projected input midway through the MLP.
        inp_proj = self.skip_proj(emb_and_coords)
        x = x + inp_proj
        x = self.after_skip(x)
        return x.squeeze(-1)  # (B, N) when out_dim == 1, else (B, N, out_dim)
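

# Minimal usage sketch (not from the original file): the sizes below are
# illustrative assumptions chosen only to exercise a forward pass.
if __name__ == "__main__":
    decoder = ImplicitDecoder(
        embed_dim=128,
        in_dim=3,
        hidden_dim=256,
        num_hidden_layers_before_skip=2,
        num_hidden_layers_after_skip=2,
        out_dim=1,
    )
    embeddings = torch.randn(4, 128)  # one latent code per sample
    coords = torch.rand(4, 1024, 3)  # 1024 query coordinates per sample
    out = decoder(embeddings, coords)
    print(out.shape)  # -> torch.Size([4, 1024]) after the trailing squeeze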