pcvae.py
"""Prediction-constrained VAE."""
import math
from models.vae import VariationalAutoencoder
from models.logreg import LogisticRegression
class PredictionConstrainedVAE(VariationalAutoencoder):
"""Variational Autoencoder with prediction constraint.
Args:
architecture (str): The architecture type of the VAE.
latent_dims (int): The number of latent dimensions.
num_classes (int): The number of classes for the classifier.
input_dims (int): The size of the input. Default is 784 for MNIST.
Attributes:
classifier (LogisticRegression): The classifier for the latent space.
"""
def __init__(
self, architecture="fc", latent_dims=20, num_classes=10, input_dims=784, distn="bern"
):
super(PredictionConstrainedVAE, self).__init__(
architecture, latent_dims, input_dims, distn
)
self.architecture = architecture
self.classifier = LogisticRegression(latent_dims, num_classes)
def forward(self, x):
"""Encodes and decodes the image. Classifies the input x based on its latent representation z."""
mu_z, sigma_z = self.encoder(x)
z = self.encoder.sample(mu_z, sigma_z)
if self.architecture == "fc" and self.decoder.distn == "bern":
image_size = math.isqrt(self.decoder.output_dims)
probs_xhat = self.decoder(z).reshape((-1, 1, image_size, image_size))
xhat = self.decoder.sample(probs_xhat)
else:
mu_x, sigma_x = self.decoder(z)
xhat = self.decoder.sample(mu_x, sigma_x)
# make classifications based on latent variable
logits = self.classifier(z)
return mu_z, sigma_z, xhat, logits
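
A minimal usage sketch, added here for illustration only: it assumes that VariationalAutoencoder in models/vae.py is a torch.nn.Module wiring up the encoder/decoder interface used above, and that the "fc" architecture accepts flattened 784-dimensional inputs (e.g. MNIST). The batch shape and the loss comment are assumptions, not part of the original module.

# --- Usage sketch (illustrative, not part of the original file) ---
if __name__ == "__main__":
    import torch

    # Assumes VariationalAutoencoder is an nn.Module so the model is callable.
    model = PredictionConstrainedVAE(
        architecture="fc", latent_dims=20, num_classes=10, input_dims=784, distn="bern"
    )
    x = torch.rand(8, 784)  # dummy batch of 8 flattened 28x28 images (assumed shape)
    mu_z, sigma_z, xhat, logits = model(x)
    print(mu_z.shape, sigma_z.shape, xhat.shape, logits.shape)
    # A prediction-constrained objective would typically combine the VAE terms
    # (reconstruction + KL computed from mu_z, sigma_z) with a weighted
    # cross-entropy on these logits; that training loss is not defined here.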