-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
swae.py
206 lines (167 loc) · 7.17 KB
/
swae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import torch
from models import BaseVAE
from torch import nn
from torch.nn import functional as F
from torch import distributions as dist
from .types_ import *
class SWAE(BaseVAE):
    """
    Sliced-Wasserstein Autoencoder.

    A deterministic convolutional autoencoder whose latent codes are pulled
    towards a standard normal prior by a sliced Wasserstein distance (SWD)
    penalty, computed over random 1-D projections of the latent codes and of
    prior samples.

    Each encoder stage halves the spatial resolution, and the linear layers
    assume a 2x2 bottleneck feature map (``hidden_dims[-1] * 4`` features).
    With the default five stages this corresponds to 64x64 inputs.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 reg_weight: float = 100,
                 wasserstein_deg: float = 2.,
                 num_projections: int = 50,
                 projection_dist: str = 'normal',
                 **kwargs) -> None:
        """
        :param in_channels: (Int) Channels of the input images; the decoder
            reconstructs the same number of channels.
        :param latent_dim: (Int) Dimensionality of the latent code (D)
        :param hidden_dims: (List[int]) Channel widths of the encoder stages,
            default [32, 64, 128, 256, 512]. The caller's list is not modified.
        :param reg_weight: Weight of the SWD term (divided by a batch-size
            correction in loss_function)
        :param wasserstein_deg: (Float) Order p of the Wasserstein distance
        :param num_projections: (Int) Number of random projections (S)
        :param projection_dist: (Str) 'normal' or 'cauchy'; distribution the
            projection directions are drawn from before normalisation
        """
        super(SWAE, self).__init__()

        self.latent_dim = latent_dim
        self.reg_weight = reg_weight
        self.p = wasserstein_deg
        self.num_projections = num_projections
        self.proj_dist = projection_dist

        # Work on a private copy: the list is reversed below for the decoder.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # `in_channels` is reused as the running width in the encoder loop,
        # so remember the image channel count for the output layer.
        image_channels = in_channels
        # Channel width of the bottleneck map; decode() reshapes to it.
        self.bottleneck_channels = hidden_dims[-1]

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # * 4: the encoder output is assumed to be a 2x2 feature map.
        self.fc_z = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()
        for c_in, c_out in zip(hidden_dims[:-1], hidden_dims[1:]):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in,
                                       c_out,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=image_channels,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input: Tensor) -> Tensor:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent codes [N x D]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        # Deterministic encoder: a single linear map to the latent code.
        return self.fc_z(result)

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps latent codes back to image space.
        :param z: (Tensor) Latent codes [N x D]
        :return: (Tensor) Reconstructions [N x C x H x W], squashed by Tanh
        """
        result = self.decoder_input(z)
        # Undo the flattening of the 2x2 bottleneck produced by the encoder.
        result = result.view(-1, self.bottleneck_channels, 2, 2)
        result = self.decoder(result)
        return self.final_layer(result)

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """
        Autoencodes the input.
        :param input: (Tensor) [N x C x H x W]
        :return: (List[Tensor]) [reconstruction, input, latent codes z]
        """
        z = self.encode(input)
        return [self.decode(z), input, z]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        """
        Computes the SWAE objective: MSE + L1 reconstruction error plus the
        weighted sliced Wasserstein penalty on the latent codes.
        :param args: (recons, input, z) as returned by forward()
        :return: (dict) 'loss', 'Reconstruction_Loss' and 'SWD' tensors
        """
        recons, inputs, z = args[0], args[1], args[2]

        batch_size = inputs.size(0)
        # N(N-1) normaliser; clamped to 1 so a batch of one does not raise
        # ZeroDivisionError.
        bias_corr = max(batch_size * (batch_size - 1), 1)
        reg_weight = self.reg_weight / bias_corr

        recons_loss_l2 = F.mse_loss(recons, inputs)
        recons_loss_l1 = F.l1_loss(recons, inputs)
        recons_loss = recons_loss_l2 + recons_loss_l1

        swd_loss = self.compute_swd(z, self.p, reg_weight)

        loss = recons_loss + swd_loss
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'SWD': swd_loss}

    def get_random_projections(self, latent_dim: int, num_samples: int) -> Tensor:
        """
        Returns random unit-norm directions in the latent space, used to
        project both the encoded samples and the prior samples.
        :param latent_dim: (Int) Dimensionality of the latent space (D)
        :param num_samples: (Int) Number of samples required (S)
        :return: (Tensor) Random projections on the unit sphere [S x D] (CPU)
        :raises ValueError: if self.proj_dist is not 'normal' or 'cauchy'
        """
        if self.proj_dist == 'normal':
            rand_samples = torch.randn(num_samples, latent_dim)
        elif self.proj_dist == 'cauchy':
            # Scalar loc/scale give exactly [S x D] samples, so no squeeze is
            # needed (squeezing would drop S or D when either equals 1).
            rand_samples = dist.Cauchy(torch.tensor(0.0),
                                       torch.tensor(1.0)).sample((num_samples, latent_dim))
        else:
            raise ValueError('Unknown projection distribution.')

        # Normalise each row to unit length.
        return rand_samples / rand_samples.norm(dim=1, keepdim=True)  # [S x D]

    def compute_swd(self,
                    z: Tensor,
                    p: float,
                    reg_weight: float) -> Tensor:
        """
        Computes the Sliced Wasserstein Distance (SWD) - which consists of
        randomly projecting the encoded and prior vectors and computing
        their Wasserstein distance along those projections.
        :param z: (Tensor) Latent samples [N x D]
        :param p: (Float) Order of the Wasserstein distance (|x|^p is used)
        :param reg_weight: (Float) Multiplier applied to the mean distance
        :return: (Tensor) Scalar weighted SWD
        """
        prior_z = torch.randn_like(z)  # [N x D]

        proj_matrix = self.get_random_projections(
            self.latent_dim, num_samples=self.num_projections
        ).transpose(0, 1).to(device=z.device, dtype=z.dtype)  # [D x S]

        latent_projections = z.matmul(proj_matrix)  # [N x S]
        prior_projections = prior_z.matmul(proj_matrix)  # [N x S]

        # In 1-D the optimal transport plan matches sorted samples, so the
        # per-projection Wasserstein distance is the difference of the sorted
        # values; abs() keeps |x|^p well-defined for odd or fractional p.
        w_dist = torch.sort(latent_projections.t(), dim=1)[0] - \
                 torch.sort(prior_projections.t(), dim=1)[0]
        w_dist = w_dist.abs().pow(p)
        return reg_weight * w_dist.mean()

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor) Decoded samples from z ~ N(0, I)
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]