Skip to content

Commit 78deae9

Browse files
committed
Trains a vMF VAE, adds regularized uncertainty
1 parent 22033ca commit 78deae9

File tree

6 files changed

+170
-31
lines changed

6 files changed

+170
-31
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
models/
2+
motion69_06.npz

examples/black_box_random_geometries/von_mises_fisher_example/main.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

examples/black_box_random_geometries/von_mises_fisher_example/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def run(
9999
n_bones = len(bones)
100100

101101
# Loads up the model.
102-
model = VAE_Motion(n_bones=n_bones, n_hidden=n_hidden)
102+
model = VAE_Motion(n_bones=n_bones, n_hidden=n_hidden, radii=radii)
103103
print(model)
104104

105105
# Trains

examples/black_box_random_geometries/von_mises_fisher_example/vae_motion.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
from stochman.manifold import StatisticalManifold
1313

1414

15-
class VAE_Motion(torch.nn.Module, StatisticalManifold):
16-
def __init__(self, n_bones: int, n_hidden: int) -> None:
15+
class VAE_Motion(torch.nn.Module):
16+
def __init__(self, n_bones: int, n_hidden: int, radii: torch.Tensor) -> None:
1717
super().__init__()
1818
self.n_bones = n_bones
1919
self.n_hidden = n_hidden
2020
self.input_dim = n_bones * 3
21+
self.radii = torch.from_numpy(radii).unsqueeze(0).type(torch.float)
2122

2223
# An encoder for N(mu, sigma) in latent space (dim 2)
2324
self.encoder = nn.Linear(self.input_dim, self.n_hidden)
@@ -73,7 +74,8 @@ def forward(self, x: torch.Tensor):
7374
return q_z_given_x, p_x_given_z
7475

7576
def elbo_loss(self, x: torch.Tensor, q_z_given_x: Normal, p_x_given_z: VonMisesFisher):
76-
rec_loss = -p_x_given_z.log_prob(x).sum(dim=1)
77+
rec_loss = -(self.radii * p_x_given_z.log_prob(x)).sum(dim=1)
7778
kl = kl_divergence(q_z_given_x, self.p_z).sum(dim=1)
7879

79-
return (rec_loss + kl).mean()
80+
beta = 0.0001
81+
return (rec_loss + beta * kl).mean()
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import torch
2+
import torch.nn as nn
3+
from sklearn.cluster import KMeans
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
7+
from stochman.manifold import StatisticalManifold
8+
9+
from vae_motion import VAE_Motion
10+
from vmf import VonMisesFisher
11+
from data_utils import load_bones_data
12+
13+
14+
class TranslatedSigmoid(nn.Module):
    """
    A translated sigmoid function that is used
    to regulate entropy in entropy networks.

    With b = softplus(beta) > 0 and c = 6.9077542789816375, the layer
    computes sigmoid((x - c * b) / b). At x = 0 this is sigmoid(-c),
    roughly 1e-3, and it tends to 1 as x grows.

    Input:
        - beta: float (-1.5 by default). Raw value of the trainable
          temperature; it is passed through a softplus before use, so
          negative values are valid. The lower, the narrower the region
          of low entropy. Negative values give nice latent spaces.
    """

    # Translation of the sigmoid, expressed in units of the temperature.
    # It is what makes the output approximately 1e-3 at x = 0.
    _ZERO_SHIFT = 6.9077542789816375

    def __init__(self, beta: float = -1.5) -> None:
        super().__init__()
        # Stored unconstrained; forward() maps it to a positive temperature.
        self.beta = nn.Parameter(torch.tensor([beta]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the translated sigmoid element-wise to x.
        """
        temperature = torch.nn.functional.softplus(self.beta)
        shift = -temperature * self._ZERO_SHIFT
        return torch.sigmoid((x + shift) / temperature)
35+
36+
37+
class VAE_motion_UQ(VAE_Motion, StatisticalManifold):
    """
    A VAE_Motion whose decoded vMF concentration is pulled towards
    a fixed value (limit_k) for latent codes far from the encoded
    training data.

    "Far" is measured as the squared Euclidean distance to the
    closest of the KMeans centers fitted on the training encodings.
    update_cluster_centers must be called before decode,
    min_distance, similarity or plot_latent_space.

    Input:
        - n_bones, n_hidden, radii: forwarded unchanged to VAE_Motion.
          radii is a NumPy array (the parent converts it with
          torch.from_numpy).
        - n_clusters: number of KMeans centers (500 by default).
        - beta: raw temperature given to TranslatedSigmoid
          (-18.5 by default).
        - limit_k: concentration used away from the data
          (0.1 by default).
    """

    def __init__(
        self,
        n_bones: int,
        n_hidden: int,
        radii: np.ndarray,
        n_clusters: int = 500,
        beta: float = -18.5,
        limit_k: float = 0.1,
    ) -> None:
        super().__init__(n_bones, n_hidden, radii)

        self.n_clusters = n_clusters
        self.beta = beta
        self.limit_k = limit_k

        # Populated by update_cluster_centers.
        self.translated_sigmoid = None
        self.encodings = None
        self.cluster_centers = None

    def update_cluster_centers(self, training_data: torch.Tensor) -> None:
        """
        Encodes the training data and
        runs KMeans to get centers.

        training_data is unpacked as (batch, n_bones, 3) and
        flattened to (batch, n_bones * 3) before encoding. Also
        (re)creates the TranslatedSigmoid used by similarity.
        """
        self.translated_sigmoid = TranslatedSigmoid(self.beta)

        batch_size, n_bones, _ = training_data.shape
        flat_training_data = training_data.reshape(batch_size, n_bones * 3)

        # The encodings are only clustered and plotted, so there is
        # no need to keep an autograd graph over the whole dataset.
        with torch.no_grad():
            self.encodings = self.encode(flat_training_data).mean

        kmeans = KMeans(n_clusters=self.n_clusters)
        kmeans.fit(self.encodings.detach().numpy())

        # Keep the centers in the dtype of the encodings so that the
        # torch.mm in min_distance never mixes float32 and float64.
        self.cluster_centers = torch.from_numpy(kmeans.cluster_centers_).to(
            self.encodings.dtype
        )

    def min_distance(self, z: torch.Tensor) -> torch.Tensor:
        """
        V(z) in the notation of the paper

        Returns, for every latent code in z (shape (..., zdim)), the
        *squared* Euclidean distance to its closest cluster center,
        with shape z.shape[:-1].

        TODO: clean or move to utilities
        """
        if self.cluster_centers is None:
            raise RuntimeError(
                "Cluster centers are not set; call update_cluster_centers first."
            )

        leading_shape = z.shape[:-1]
        z = z.reshape(-1, z.shape[-1])  # Nx(zdim)
        centers = self.cluster_centers.to(dtype=z.dtype, device=z.device)

        # ||z - c||^2 = ||z||^2 + ||c||^2 - 2 <z, c>
        z_norm = (z ** 2).sum(1, keepdim=True)  # Nx1
        center_norm = (centers ** 2).sum(1).view(1, -1)  # 1x(num_clusters)
        d2 = z_norm + center_norm - 2.0 * torch.mm(z, centers.transpose(0, 1))

        # Rounding can make the expansion slightly negative.
        d2 = d2.clamp(min=0.0)  # Nx(num_clusters)
        min_dist, _ = d2.min(dim=1)  # N
        return min_dist.view(leading_shape)

    def similarity(self, v: torch.Tensor) -> torch.Tensor:
        """
        T(z) or alpha in the notation of the paper, but
        backwards and translated to make it 0 at 0 and
        1 at infinity.
        """
        if self.translated_sigmoid is None:
            raise RuntimeError(
                "The translated sigmoid is not set; call update_cluster_centers first."
            )

        return self.translated_sigmoid(v)

    def decode(self, z: torch.Tensor) -> VonMisesFisher:
        """
        An alternate version of the decoder that pushes
        k to self.limit_k away of the support of the
        data.

        z is flattened to (-1, zdim) before decoding, and the
        returned distribution keeps that flattened batch shape.
        """
        z = z.reshape(-1, z.shape[-1])
        original_vMF = super().decode(z)
        dec_mu = original_vMF.loc
        dec_k = original_vMF.scale

        # Distance to the supp.: alpha is close to 0 near the cluster
        # centers and close to 1 far away, so k interpolates between
        # the decoded value and limit_k.
        alpha = self.similarity(self.min_distance(z)).unsqueeze(-1)
        reweighted_k = (1 - alpha) * dec_k + alpha * self.limit_k

        return VonMisesFisher(dec_mu, reweighted_k)

    def plot_latent_space(self) -> None:
        """
        Draws the mean decoded concentration over a grid spanning the
        first two latent dimensions of the training encodings, with
        the encodings and the cluster centers scattered on top.

        Only builds the figure; showing it is left to the caller.
        """
        if self.encodings is None or self.cluster_centers is None:
            raise RuntimeError(
                "Nothing to plot; call update_cluster_centers first."
            )

        encodings = self.encodings.detach().numpy()
        centers = self.cluster_centers.detach().numpy()
        enc_x, enc_y = encodings[:, 0], encodings[:, 1]

        n_x, n_y = 300, 300
        x_lims = (float(enc_x.min()), float(enc_x.max()))
        y_lims = (float(enc_y.min()), float(enc_y.max()))
        z1 = torch.linspace(*x_lims, n_x)
        z2 = torch.linspace(*y_lims, n_y)

        # All (x, y) pairs, x varying slowest: row l = ix * n_y + iy.
        zs = torch.cartesian_prod(z1, z2)
        with torch.no_grad():
            ks = self.decode(zs).scale.detach().numpy()
        mean_ks = np.mean(ks, axis=1)

        # imshow expects rows to run along y (top row = largest y) and
        # columns along x, hence the transpose and vertical flip.
        K = np.flipud(mean_ks.reshape(n_x, n_y).T)

        _, ax = plt.subplots(1, 1)
        ax.scatter(enc_x, enc_y, s=1)
        ax.scatter(centers[:, 0], centers[:, 1])
        plot = ax.imshow(K, extent=[*x_lims, *y_lims])
        plt.colorbar(plot, ax=ax)
140+
141+
142+
if __name__ == "__main__":
    # Loads a trained model, fits the cluster centers on the training
    # set and shows the resulting concentration map of the latent space.

    n_hidden = 30
    bones = torch.tensor([1, 2, 3, 4, 6, 7, 8, 9, 13, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
    train_dataset, _, radii = load_bones_data(bones=bones)
    n_bones = len(bones)

    vae_uq = VAE_motion_UQ(n_bones, n_hidden, radii)

    # The path is relative to the current working directory.
    # map_location="cpu" lets a checkpoint saved on a GPU load on a
    # CPU-only machine; everything downstream runs on CPU (.numpy()).
    vae_uq.load_state_dict(
        torch.load(
            "./examples/black_box_random_geometries/von_mises_fisher_example/models/motion.pt",
            map_location="cpu",
        )
    )
    vae_uq.update_cluster_centers(train_dataset.tensors[0])
    vae_uq.plot_latent_space()
    plt.show()

examples/black_box_random_geometries/von_mises_fisher_example/vmf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class VonMisesFisher(torch.distributions.Distribution):
5252

5353
@property
5454
def mean(self):
55-
return self.loc * (self.ive(self.__m / 2, self.scale) / self.ive(self.__m / 2 - 1, self.scale))
55+
return self.loc * (
56+
self.ive(self.__m / 2, self.scale.unsqueeze(-1))
57+
/ self.ive(self.__m / 2 - 1, self.scale.unsqueeze(-1))
58+
)
5659

5760
@property
5861
def stddev(self):

0 commit comments

Comments
 (0)