@@ -95,7 +95,7 @@ def reweight(self, z: torch.Tensor) -> torch.Tensor:
95
95
"""
96
96
zsh = z .shape
97
97
z = z .reshape (- 1 , zsh [- 1 ])
98
- dec_mu , dec_k = self .decode (z ) # Nx(num_bones)x3, Nx(num_bones)
98
+ dec_mu , dec_k = super () .decode (z ) # Nx(num_bones)x3, Nx(num_bones)
99
99
100
100
# Distance to the supp.
101
101
alpha = self .similarity (self .min_distance (z )).unsqueeze (- 1 )
@@ -106,6 +106,14 @@ def reweight(self, z: torch.Tensor) -> torch.Tensor:
106
106
ksh = dec_k .shape
107
107
return dec_mu .view (zsh [:- 1 ] + mush [1 :]), reweighted_k .view (zsh [:- 1 ] + ksh [1 :])
108
108
109
+ def decode (self , z , reweight = True ):
110
+ if reweight :
111
+ mu , k = self .reweight (z )
112
+ else :
113
+ mu , k = super ().decode (z )
114
+
115
+ return VonMisesFisher (loc = mu , scale = k )
116
+
109
117
def forward (self , x ):
110
118
"""
111
119
Forward pass through the network, w. extrapolation
@@ -130,7 +138,7 @@ def forward(self, x):
130
138
131
139
return x , p_mu , p_k .unsqueeze (2 ), q_mu , q_var
132
140
133
- def plot_latent_space (self ):
141
+ def plot_latent_space (self , ax = None ):
134
142
encodings = self .encodings .detach ().numpy ()
135
143
enc_x , enc_y = encodings [:, 0 ], encodings [:, 1 ]
136
144
@@ -152,7 +160,9 @@ def plot_latent_space(self):
152
160
i , j = positions [(x .item (), y .item ())]
153
161
K [i , j ] = mean_ks [l ]
154
162
155
- _ , ax = plt .subplots (1 , 1 )
163
+ if ax is None :
164
+ _ , ax = plt .subplots (1 , 1 )
165
+
156
166
ax .scatter (encodings [:, 0 ], encodings [:, 1 ], s = 1 )
157
167
# ax.scatter(self.cluster_centers[:, 0], self.cluster_centers[:, 1])
158
168
plot = ax .imshow (K , extent = [* x_lims , * y_lims ])
@@ -173,6 +183,13 @@ def plot_latent_space(self):
173
183
torch .load ("./examples/black_box_random_geometries/von_mises_fisher_example/models/motion_2.pt" )
174
184
)
175
185
vae_uq .update_cluster_centers ()
176
- vae_uq .plot_latent_space ()
186
+ vae_manifold = StatisticalManifold (vae_uq )
187
+
188
+ _ , ax = plt .subplots (1 , 1 , figsize = (7 , 7 ))
189
+ vae_uq .plot_latent_space (ax = ax )
190
+ for _ in range (10 ):
191
+ idx_1 , idx_2 = np .random .randint (0 , len (vae_uq .encodings ), size = (2 ,))
192
+ geodesic , _ = vae_manifold .connecting_geodesic (vae_uq .encodings [idx_1 ], vae_uq .encodings [idx_2 ])
193
+ geodesic .plot (ax = ax )
194
+
177
195
plt .show ()
178
- pass
0 commit comments