import torch
import torch.nn as nn
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import numpy as np

from stochman.manifold import StatisticalManifold

from vae_motion import VAE_Motion
from vmf import VonMisesFisher
from data_utils import load_bones_data


class TranslatedSigmoid(nn.Module):
    """
    A translated sigmoid function used to regulate
    entropy in entropy networks.

    Input:
    - beta: float (-1.5 by default). The lower it is,
      the narrower the region of low entropy.
      Negative values tend to give nice latent spaces.
    """

    def __init__(self, beta: float = -1.5) -> None:
        super().__init__()
        self.beta = nn.Parameter(torch.tensor([beta]))

    def forward(self, x):
        beta = torch.nn.functional.softplus(self.beta)
        # The constant is approximately log(1e3), which translates the
        # sigmoid so that it evaluates to roughly 1e-3 at x = 0.
        alpha = -beta * 6.9077542789816375
        val = torch.sigmoid((x + alpha) / beta)

        return val


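# The model below combines the motion VAE with stochman's StatisticalManifold:
# the decoder's vMF concentration is pulled towards a small constant away from
# the support of the training data, so that uncertainty grows off-manifold and
# the resulting decoder can be used for latent-space geometry computations.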
class VAE_motion_UQ(VAE_Motion, StatisticalManifold):
    def __init__(self, n_bones: int, n_hidden: int, radii: torch.Tensor) -> None:
        super().__init__(n_bones, n_hidden, radii)

        self.n_clusters = 500  # number of KMeans centers summarizing the training encodings
        self.beta = -18.5  # beta of the translated sigmoid; very negative -> sharp transition
        self.limit_k = 0.1  # vMF concentration enforced far away from the data

        self.translated_sigmoid = None
        self.encodings = None
        self.cluster_centers = None

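    # The KMeans centers act as a compact proxy for the support of the
    # training data in latent space; distances to them are what the
    # translated sigmoid turns into the in/out-of-support score alpha.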
    def update_cluster_centers(self, training_data: torch.Tensor):
        """
        Encodes the training data and
        runs KMeans to get the cluster centers.
        """
        self.translated_sigmoid = TranslatedSigmoid(self.beta)

        batch_size, n_bones, _ = training_data.shape
        flat_training_data = training_data.reshape(batch_size, n_bones * 3)
        # Use the posterior means as the latent encodings of the data.
        self.encodings = self.encode(flat_training_data).mean
        Z = self.encodings.detach().numpy()
        kmeans = KMeans(n_clusters=self.n_clusters)
        kmeans.fit(Z)
        self.cluster_centers = torch.from_numpy(kmeans.cluster_centers_)

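    # V(z) below is computed with the usual expansion
    # ||z - c||^2 = ||z||^2 + ||c||^2 - 2<z, c>,
    # evaluated against all cluster centers at once; the minimum over
    # centers is the (squared) distance from z to the data.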
    def min_distance(self, z: torch.Tensor) -> torch.Tensor:
        """
        V(z) in the notation of the paper: the minimum
        squared distance from z to the cluster centers.

        TODO: clean or move to utilities
        """
        # z has shape (batch, zdim); flatten any leading dimensions.
        zsh = z.shape
        z = z.view(-1, z.shape[-1])  # Nx(zdim)

        z_norm = (z ** 2).sum(1, keepdim=True)  # Nx1
        center_norm = (self.cluster_centers ** 2).sum(1).view(1, -1)  # 1x(num_clusters)
        d2 = (
            z_norm + center_norm - 2.0 * torch.mm(z, self.cluster_centers.transpose(0, 1))
        )  # Nx(num_clusters)
        d2.clamp_(min=0.0)  # numerical noise can make d2 slightly negative
        min_dist, _ = d2.min(dim=1)  # N
        return min_dist.view(zsh[:-1])

    def similarity(self, v: torch.Tensor) -> torch.Tensor:
        """
        T(z) (or alpha) in the notation of the paper, but
        reversed and translated so that it is 0 at 0 and
        1 at infinity.
        """
        return self.translated_sigmoid(v)

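    # decode() below blends the learned concentration with the constant
    # limit_k as reweighted_k = (1 - alpha) * dec_k + alpha * limit_k,
    # where alpha = T(V(z)). Near the data alpha ~ 0 and the network's k
    # is kept; far away alpha ~ 1, k collapses to limit_k, and the decoded
    # vMF becomes nearly uniform on the sphere (maximal uncertainty).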
    def decode(self, z: torch.Tensor) -> VonMisesFisher:
        """
        An alternate version of the decoder that pushes
        k towards self.limit_k away from the support of
        the data.
        """
        zsh = z.shape
        z = z.reshape(-1, zsh[-1])
        original_vMF = super().decode(z)
        dec_mu = original_vMF.loc
        dec_k = original_vMF.scale

        # Similarity based on the distance to the support of the data.
        alpha = self.similarity(self.min_distance(z)).unsqueeze(-1)
        reweighted_k = (1 - alpha) * dec_k + alpha * (torch.ones_like(dec_k) * self.limit_k)

        return VonMisesFisher(dec_mu, reweighted_k)

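    # plot_latent_space() draws a heatmap of the mean concentration k over
    # a grid in latent space, overlaid with the training encodings and the
    # cluster centers; high-k regions should track the data, decaying
    # towards limit_k elsewhere.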
    def plot_latent_space(self):
        encodings = self.encodings.detach().numpy()
        enc_x, enc_y = encodings[:, 0], encodings[:, 1]

        n_x, n_y = 300, 300
        x_lims = (enc_x.min(), enc_x.max())
        y_lims = (enc_y.min(), enc_y.max())
        z1 = torch.linspace(*x_lims, n_x)
        z2 = torch.linspace(*y_lims, n_y)

        K = np.zeros((n_y, n_x))
        zs = torch.Tensor([[x, y] for x in z1 for y in z2])
        # Map each grid point to its pixel in K (rows run top-to-bottom,
        # hence the reversed y axis).
        positions = {
            (x.item(), y.item()): (i, j) for j, x in enumerate(z1) for i, y in enumerate(reversed(z2))
        }
        vMF = self.decode(zs)
        ks = vMF.scale
        ks = ks.detach().numpy()
        mean_ks = np.mean(ks, axis=1)
        for l, (x, y) in enumerate(zs):
            i, j = positions[(x.item(), y.item())]
            K[i, j] = mean_ks[l]

        _, ax = plt.subplots(1, 1)
        ax.scatter(encodings[:, 0], encodings[:, 1], s=1)
        ax.scatter(self.cluster_centers[:, 0], self.cluster_centers[:, 1])
        plot = ax.imshow(K, extent=[*x_lims, *y_lims])
        plt.colorbar(plot, ax=ax)


if __name__ == "__main__":
    # Loading the model.
    n_hidden = 30
    bones = torch.tensor([1, 2, 3, 4, 6, 7, 8, 9, 13, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
    train_dataset, test_dataset, radii = load_bones_data(bones=bones)
    n_bones = len(bones)

    vae_uq = VAE_motion_UQ(n_bones, n_hidden, radii)
    vae_uq.load_state_dict(
        torch.load("./examples/black_box_random_geometries/von_mises_fisher_example/models/motion.pt")
    )
    vae_uq.update_cluster_centers(train_dataset.tensors[0])
    vae_uq.plot_latent_space()
    plt.show()