|
from typing import Any, Optional

import torch
import torch.nn as nn

from ..utils import export, load_from_local_or_url
| 6 | + |
| 7 | + |
| 8 | +@export |
| 9 | +class ConditionalVAE(nn.Module): |
| 10 | + """ |
| 11 | + Paper: [Learning Structured Output Representation using Deep Conditional Generative Models](https://papers.nips.cc/paper/2015/hash/8d55a249e6baa5c06772297520da2051-Abstract.html) |
| 12 | + """ |
| 13 | + def __init__( |
| 14 | + self, |
| 15 | + image_size, |
| 16 | + nz: int = 100, |
| 17 | + **kwargs: Any |
| 18 | + ): |
| 19 | + super().__init__() |
| 20 | + |
| 21 | + self.image_size = image_size |
| 22 | + self.nz = nz |
| 23 | + |
| 24 | + self.embeds_en = nn.Embedding(10, 200) |
| 25 | + |
| 26 | + self.embeds_de = nn.Embedding(10, 10) |
| 27 | + |
| 28 | + # Q(z|X) |
| 29 | + self.encoder = nn.Sequential( |
| 30 | + nn.Linear(self.image_size ** 2 + 200, 512), |
| 31 | + nn.LeakyReLU(0.2, inplace=True), |
| 32 | + nn.Linear(512, 512), |
| 33 | + nn.LeakyReLU(0.2, inplace=True), |
| 34 | + nn.Linear(512, 256), |
| 35 | + nn.LeakyReLU(0.2, inplace=True), |
| 36 | + nn.Linear(256, self.nz * 2) |
| 37 | + ) |
| 38 | + |
| 39 | + # P(X|z) |
| 40 | + self.decoder = nn.Sequential( |
| 41 | + nn.Linear(self.nz + 10, 256), |
| 42 | + nn.LeakyReLU(0.2, inplace=True), |
| 43 | + nn.Linear(256, 512), |
| 44 | + nn.LeakyReLU(0.2, inplace=True), |
| 45 | + nn.Linear(512, 512), |
| 46 | + nn.LeakyReLU(0.2, inplace=True), |
| 47 | + nn.Linear(512, self.image_size ** 2), |
| 48 | + nn.Sigmoid(), |
| 49 | + nn.Unflatten(1, (1, image_size, image_size)) |
| 50 | + ) |
| 51 | + |
| 52 | + def sample_z(self, mu, logvar, c): |
| 53 | + eps = torch.randn_like(logvar) |
| 54 | + |
| 55 | + return torch.cat([mu + eps * torch.exp(0.5 * logvar), c], dim=1) |
| 56 | + |
| 57 | + def forward(self, x, c): |
| 58 | + x = torch.flatten(x, 1) |
| 59 | + |
| 60 | + x = torch.cat([x, self.embeds_en(c)], dim=1) |
| 61 | + |
| 62 | + mu, logvar = torch.chunk(self.encoder(x), 2, dim=1) |
| 63 | + |
| 64 | + z = self.sample_z(mu, logvar, self.embeds_de(c)) |
| 65 | + |
| 66 | + x = self.decoder(z) |
| 67 | + return x, mu, logvar |
| 68 | + |
| 69 | + |
| 70 | +@export |
| 71 | +def cvae( |
| 72 | + pretrained: bool = False, |
| 73 | + pth: str = None, |
| 74 | + progress: bool = True, |
| 75 | + **kwargs: Any |
| 76 | +): |
| 77 | + model = ConditionalVAE(**kwargs) |
| 78 | + |
| 79 | + if pretrained: |
| 80 | + load_from_local_or_url(model, pth, kwargs.get('url', None), progress) |
| 81 | + return model |
0 commit comments