Skip to content

Commit 805189f

Browse files
committed
Add 'CVAE' model
1 parent 0453c74 commit 805189f

File tree

4 files changed

+86
-2
lines changed

4 files changed

+86
-2
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@
9898
### VAEs
9999

100100
- [x] [`VAE`](cvm/models/vae/vae.py) - [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114), 2013
101+
- [x] [`CVAE`](cvm/models/vae/cvae.py) - [Learning Structured Output Representation using Deep Conditional Generative Models
102+
](https://papers.nips.cc/paper/2015/hash/8d55a249e6baa5c06772297520da2051-Abstract.html), NeurIPS, 2015
101103
- [ ] `β-VAE` - [beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework](https://openreview.net/forum?id=Sy2fzU9gl), ICLR, 2017
102104

103105

cvm/models/vae/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .vae import *
1+
from .vae import *
2+
from .cvae import *

cvm/models/vae/cvae.py

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from ..utils import export, load_from_local_or_url
5+
from typing import Any
6+
7+
8+
@export
9+
class ConditionalVAE(nn.Module):
10+
"""
11+
Paper: [Learning Structured Output Representation using Deep Conditional Generative Models](https://papers.nips.cc/paper/2015/hash/8d55a249e6baa5c06772297520da2051-Abstract.html)
12+
"""
13+
def __init__(
14+
self,
15+
image_size,
16+
nz: int = 100,
17+
**kwargs: Any
18+
):
19+
super().__init__()
20+
21+
self.image_size = image_size
22+
self.nz = nz
23+
24+
self.embeds_en = nn.Embedding(10, 200)
25+
26+
self.embeds_de = nn.Embedding(10, 10)
27+
28+
# Q(z|X)
29+
self.encoder = nn.Sequential(
30+
nn.Linear(self.image_size ** 2 + 200, 512),
31+
nn.LeakyReLU(0.2, inplace=True),
32+
nn.Linear(512, 512),
33+
nn.LeakyReLU(0.2, inplace=True),
34+
nn.Linear(512, 256),
35+
nn.LeakyReLU(0.2, inplace=True),
36+
nn.Linear(256, self.nz * 2)
37+
)
38+
39+
# P(X|z)
40+
self.decoder = nn.Sequential(
41+
nn.Linear(self.nz + 10, 256),
42+
nn.LeakyReLU(0.2, inplace=True),
43+
nn.Linear(256, 512),
44+
nn.LeakyReLU(0.2, inplace=True),
45+
nn.Linear(512, 512),
46+
nn.LeakyReLU(0.2, inplace=True),
47+
nn.Linear(512, self.image_size ** 2),
48+
nn.Sigmoid(),
49+
nn.Unflatten(1, (1, image_size, image_size))
50+
)
51+
52+
def sample_z(self, mu, logvar, c):
53+
eps = torch.randn_like(logvar)
54+
55+
return torch.cat([mu + eps * torch.exp(0.5 * logvar), c], dim=1)
56+
57+
def forward(self, x, c):
58+
x = torch.flatten(x, 1)
59+
60+
x = torch.cat([x, self.embeds_en(c)], dim=1)
61+
62+
mu, logvar = torch.chunk(self.encoder(x), 2, dim=1)
63+
64+
z = self.sample_z(mu, logvar, self.embeds_de(c))
65+
66+
x = self.decoder(z)
67+
return x, mu, logvar
68+
69+
70+
@export
71+
def cvae(
72+
pretrained: bool = False,
73+
pth: str = None,
74+
progress: bool = True,
75+
**kwargs: Any
76+
):
77+
model = ConditionalVAE(**kwargs)
78+
79+
if pretrained:
80+
load_from_local_or_url(model, pth, kwargs.get('url', None), progress)
81+
return model

cvm/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.20'
1+
__version__ = '0.1.21'

0 commit comments

Comments
 (0)