Skip to content

Commit 6f92c65

Browse files
committed
Gets it to computing geodesics
1 parent e5e96c1 commit 6f92c65

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

examples/black_box_random_geometries/von_mises_fisher_example/vae_w_regularized_uncertainty.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def reweight(self, z: torch.Tensor) -> torch.Tensor:
9595
"""
9696
zsh = z.shape
9797
z = z.reshape(-1, zsh[-1])
98-
dec_mu, dec_k = self.decode(z) # Nx(num_bones)x3, Nx(num_bones)
98+
dec_mu, dec_k = super().decode(z) # Nx(num_bones)x3, Nx(num_bones)
9999

100100
# Distance to the supp.
101101
alpha = self.similarity(self.min_distance(z)).unsqueeze(-1)
@@ -106,6 +106,14 @@ def reweight(self, z: torch.Tensor) -> torch.Tensor:
106106
ksh = dec_k.shape
107107
return dec_mu.view(zsh[:-1] + mush[1:]), reweighted_k.view(zsh[:-1] + ksh[1:])
108108

109+
def decode(self, z, reweight=True):
110+
if reweight:
111+
mu, k = self.reweight(z)
112+
else:
113+
mu, k = super().decode(z)
114+
115+
return VonMisesFisher(loc=mu, scale=k)
116+
109117
def forward(self, x):
110118
"""
111119
Forward pass through the network, w. extrapolation
@@ -130,7 +138,7 @@ def forward(self, x):
130138

131139
return x, p_mu, p_k.unsqueeze(2), q_mu, q_var
132140

133-
def plot_latent_space(self):
141+
def plot_latent_space(self, ax=None):
134142
encodings = self.encodings.detach().numpy()
135143
enc_x, enc_y = encodings[:, 0], encodings[:, 1]
136144

@@ -152,7 +160,9 @@ def plot_latent_space(self):
152160
i, j = positions[(x.item(), y.item())]
153161
K[i, j] = mean_ks[l]
154162

155-
_, ax = plt.subplots(1, 1)
163+
if ax is None:
164+
_, ax = plt.subplots(1, 1)
165+
156166
ax.scatter(encodings[:, 0], encodings[:, 1], s=1)
157167
# ax.scatter(self.cluster_centers[:, 0], self.cluster_centers[:, 1])
158168
plot = ax.imshow(K, extent=[*x_lims, *y_lims])
@@ -173,6 +183,13 @@ def plot_latent_space(self):
173183
torch.load("./examples/black_box_random_geometries/von_mises_fisher_example/models/motion_2.pt")
174184
)
175185
vae_uq.update_cluster_centers()
176-
vae_uq.plot_latent_space()
186+
vae_manifold = StatisticalManifold(vae_uq)
187+
188+
_, ax = plt.subplots(1, 1, figsize=(7, 7))
189+
vae_uq.plot_latent_space(ax=ax)
190+
for _ in range(10):
191+
idx_1, idx_2 = np.random.randint(0, len(vae_uq.encodings), size=(2,))
192+
geodesic, _ = vae_manifold.connecting_geodesic(vae_uq.encodings[idx_1], vae_uq.encodings[idx_2])
193+
geodesic.plot(ax=ax)
194+
177195
plt.show()
178-
pass

examples/black_box_random_geometries/von_mises_fisher_example/vmf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,16 @@ def log_prob(self, x):
196196
return self._log_unnormalized_prob(x) - self._log_normalization()
197197

198198
def _log_unnormalized_prob(self, x):
199-
output = self.scale * (self.loc * x).sum(-1, keepdim=True)
199+
output = self.scale.unsqueeze(-1) * (self.loc * x).sum(-1, keepdim=True)
200200

201201
return output.view(*(output.shape[:-1]))
202202

203203
def _log_normalization(self):
204+
scale = self.scale.unsqueeze(-1)
204205
output = -(
205-
(self.__m / 2 - 1) * torch.log(self.scale)
206+
(self.__m / 2 - 1) * torch.log(scale)
206207
- (self.__m / 2) * math.log(2 * math.pi)
207-
- (self.scale + torch.log(self.ive(self.__m / 2 - 1, self.scale)))
208+
- (scale + torch.log(self.ive(self.__m / 2 - 1, scale)))
208209
)
209210

210211
return output.view(*(output.shape[:-1]))

0 commit comments

Comments
 (0)