
Commit

Added more tests to distributions
VincentStimper committed Jan 22, 2023
1 parent 921a703 commit 014ac07
Showing 7 changed files with 184 additions and 10 deletions.
2 changes: 1 addition & 1 deletion normflows/distributions/decoder_test.py
@@ -8,7 +8,7 @@
 
 class DecoderTest(unittest.TestCase):
 
-    def test_normalizing_flow_vae(self):
+    def test_decoder(self):
         batch_size = 5
         n_dim = 10
         n_bottleneck = 3
1 change: 0 additions & 1 deletion normflows/distributions/distribution_test.py
@@ -1,5 +1,4 @@
 import unittest
 import torch
 
 from torch.testing import assert_close

8 changes: 4 additions & 4 deletions normflows/distributions/encoder.py
@@ -55,13 +55,13 @@ def __init__(self, zmin=0.0, zmax=1.0):
         super().__init__()
         self.zmin = zmin
         self.zmax = zmax
-        self.log_q = -torch.log(zmax - zmin)
+        self.log_q = -np.log(zmax - zmin)
 
     def forward(self, x, num_samples=1):
         z = (
             x.unsqueeze(1)
             .repeat(1, num_samples, 1)
-            .uniform_(min=self.zmin, max=self.zmax)
+            .uniform_(self.zmin, self.zmax)
         )
         log_q = torch.zeros(z.size()[0:2]).fill_(self.log_q)
         return z, log_q
@@ -82,9 +82,9 @@ def __init__(self, loc, scale):
         super().__init__()
         self.d = len(loc)
         if not torch.is_tensor(loc):
-            loc = torch.tensor(loc).float()
+            loc = torch.tensor(loc)
         if not torch.is_tensor(scale):
-            scale = torch.tensor(scale).float()
+            scale = torch.tensor(scale)
         self.loc = nn.Parameter(loc.reshape((1, 1, self.d)))
         self.scale = nn.Parameter(scale)

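The two fixes above are easy to sanity-check in isolation; a minimal sketch, assuming only stock PyTorch and NumPy. `torch.log` only accepts tensors, so the plain float bounds need `np.log`, and `Tensor.uniform_` takes its bounds positionally (its documented keyword names are `from`/`to`, so the old `min=`/`max=` call raised a `TypeError`):

```python
import numpy as np
import torch

zmin, zmax = 0.0, 1.0

# torch.log requires a tensor argument, so plain float bounds go through np.log.
log_q = -np.log(zmax - zmin)

# Tensor.uniform_ takes its bounds positionally; its documented keywords
# are `from`/`to`, so passing min=/max= raises a TypeError.
z = torch.zeros(3, 2).uniform_(zmin, zmax)

print(log_q, z.min().item(), z.max().item())
```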
99 changes: 99 additions & 0 deletions normflows/distributions/encoder_test.py
@@ -0,0 +1,99 @@
import unittest
import torch
import numpy as np

from torch.testing import assert_close

from normflows.nets import MLP
from normflows.distributions.encoder import Dirac, Uniform, \
ConstDiagGaussian, NNDiagGaussian


class EncoderTest(unittest.TestCase):

def checkForward(self, encoder, inputs, num_samples=1):
# Do forward
outputs, log_p = encoder(inputs, num_samples)
# Check type
assert log_p.dtype == inputs.dtype
assert outputs.dtype == inputs.dtype
# Check shape
assert log_p.shape[0] == inputs.shape[0]
assert outputs.shape[0] == inputs.shape[0]
assert log_p.shape[1] == num_samples
assert outputs.shape[1] == num_samples
# Check dim
assert outputs.dim() > log_p.dim()
# Return results
return outputs, log_p

def checkLogProb(self, encoder, inputs_z, inputs_x):
# Compute log prob
log_p = encoder.log_prob(inputs_z, inputs_x)
# Check type
assert log_p.dtype == inputs_z.dtype
# Check shape
assert log_p.shape[0] == inputs_z.shape[0]
# Return results
return log_p

def checkForwardLogProb(self, encoder, inputs, num_samples=1,
atol=None, rtol=None):
# Check forward
outputs, log_p = self.checkForward(encoder, inputs,
num_samples)
# Check log prob
log_p_ = self.checkLogProb(encoder, outputs, inputs)
# Check consistency
assert_close(log_p_, log_p, atol=atol, rtol=rtol)

def test_dirac_uniform(self):
batch_size = 5
encoder_class = [Dirac, Uniform]
params = [(2, 1), (1, 3), (2, 2)]

# Test model
for n_dim, num_samples in params:
for encoder_c in encoder_class:
with self.subTest(n_dim=n_dim, num_samples=num_samples,
encoder_c=encoder_c):
# Set up encoder
encoder = encoder_c()
# Do tests
inputs = torch.rand(batch_size, n_dim)
self.checkForwardLogProb(encoder, inputs, num_samples)

def test_const_diag_gaussian(self):
batch_size = 5
params = [(2, 1), (1, 3), (2, 2)]

# Test model
for n_dim, num_samples in params:
with self.subTest(n_dim=n_dim, num_samples=num_samples):
# Set up encoder
loc = np.random.randn(n_dim).astype(np.float32)
scale = np.random.rand(n_dim).astype(np.float32) * 0.1 + 1
encoder = ConstDiagGaussian(loc, scale)
# Do tests
inputs = torch.rand(batch_size, n_dim)
self.checkForwardLogProb(encoder, inputs, num_samples)

def test_nn_diag_gaussian(self):
batch_size = 5
n_hidden = 8
params = [(4, 2, 1), (1, 1, 3), (5, 3, 2)]

# Test model
for n_dim, n_latent, num_samples in params:
with self.subTest(n_dim=n_dim, n_latent=n_latent,
num_samples=num_samples):
# Set up encoder
nn = MLP([n_dim, n_hidden, n_latent * 2])
encoder = NNDiagGaussian(nn)
# Do tests
inputs = torch.rand(batch_size, n_dim)
self.checkForwardLogProb(encoder, inputs, num_samples)


if __name__ == "__main__":
unittest.main()
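For orientation, the encoder contract these tests exercise can also be poked at directly. A minimal sketch, assuming the `ConstDiagGaussian` API shown in the diff above (`forward` returns `(z, log_q)` with a sample dimension after the batch dimension, and `log_prob(z, x)` evaluates those samples):

```python
import numpy as np
import torch

from normflows.distributions.encoder import ConstDiagGaussian

# Fixed-parameter Gaussian encoder over a 3-dimensional latent space
loc = np.random.randn(3).astype(np.float32)
scale = np.random.rand(3).astype(np.float32) * 0.1 + 1
encoder = ConstDiagGaussian(loc, scale)

x = torch.rand(5, 3)                  # batch of 5 inputs
z, log_q = encoder(x, num_samples=2)  # z: (5, 2, 3), log_q: (5, 2)
log_q_ = encoder.log_prob(z, x)       # should agree with log_q
print(torch.allclose(log_q, log_q_))
```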
13 changes: 9 additions & 4 deletions normflows/distributions/prior.py
@@ -147,10 +147,16 @@ def log_prob(self, z):
 class Sinusoidal(PriorDistribution):
     def __init__(self, scale, period):
         """Distribution 2d with sinusoidal density
+        given by
+
+        ```
+        log(p) = - 1/2 * ((z[1] - w_1(z)) / (2 * scale)) ** 2
+        w_1(z) = sin(2*pi / period * z[0])
+        ```
 
         Args:
-          loc: distance of modes from the origin
-          scale: scale of modes
+          scale: scale of the distribution, see formula
+          period: period of the sinusoidal
         """
         self.scale = scale
         self.period = period
@@ -276,8 +282,7 @@ def __init__(self, scale):
         """Distribution 2d of a smiley :)
 
         Args:
-          loc: distance of modes from the origin
-          scale: scale of modes
+          scale: scale of the smiley
         """
         self.scale = scale
         self.loc = 2.0
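The body of `log_prob` is collapsed in this diff, but the density documented in the new docstring can be written out directly. A sketch of the formula for `z` of shape `(batch, 2)` (an illustrative helper, not the repository's implementation):

```python
import numpy as np
import torch

def sinusoidal_log_prob(z, scale=1.0, period=1.0):
    # w_1(z) = sin(2 * pi / period * z[0]), applied per sample
    w1 = torch.sin(2 * np.pi / period * z[:, 0])
    # log(p) = -1/2 * ((z[1] - w_1(z)) / (2 * scale)) ** 2
    return -0.5 * ((z[:, 1] - w1) / (2 * scale)) ** 2

z = torch.randn(5, 2)
print(sinusoidal_log_prob(z).shape)  # torch.Size([5])
```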
45 changes: 45 additions & 0 deletions normflows/distributions/prior_test.py
@@ -0,0 +1,45 @@
import unittest
import torch
import numpy as np

from normflows.distributions.prior import ImagePrior, TwoModes, \
Sinusoidal, Sinusoidal_split, Sinusoidal_gap, Smiley


class PriorTest(unittest.TestCase):
def test_2d_priors(self):
batch_size = 5
prior_params = [(TwoModes, {'loc': 1, 'scale': 0.2}),
(Sinusoidal, {'scale': 1., 'period': 1.}),
(Sinusoidal_split, {'scale': 1., 'period': 1.}),
(Sinusoidal_gap, {'scale': 1., 'period': 1.}),
(Smiley, {'scale': 1.})]

# Test model
for prior_c, params in prior_params:
with self.subTest(prior_c=prior_c, params=params):
# Set up prior
prior = prior_c(**params)
# Test prior
inputs = torch.randn(batch_size, 2)
log_p = prior.log_prob(inputs)
assert log_p.shape == (batch_size,)
assert log_p.dtype == inputs.dtype

def test_image_prior(self):
# Set up prior
image = np.random.rand(10, 10).astype(np.float32)
prior = ImagePrior(image)
for num_samples in [1, 5]:
with self.subTest(num_samples=num_samples):
# Test prior
samples = prior.sample(num_samples)
assert samples.shape == (num_samples, 2)
assert samples.dtype == torch.float32
log_p = prior.log_prob(samples)
assert log_p.shape == (num_samples,)
assert log_p.dtype == torch.float32


if __name__ == "__main__":
unittest.main()
26 changes: 26 additions & 0 deletions normflows/distributions/target_test.py
@@ -0,0 +1,26 @@
import unittest
import torch

from normflows.distributions.target import TwoMoons, \
CircularGaussianMixture, RingMixture


class TargetTest(unittest.TestCase):
def test_targets(self):
targets = [TwoMoons, CircularGaussianMixture,
RingMixture]
for num_samples in [1, 5]:
for target_ in targets:
with self.subTest(num_samples=num_samples,
target_=target_):
                    # Set up target
target = target_()
                    # Test target
samples = target.sample(num_samples)
assert samples.shape == (num_samples, 2)
log_p = target.log_prob(samples)
assert log_p.shape == (num_samples,)


if __name__ == "__main__":
unittest.main()
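Since the new modules end in `_test.py` rather than unittest's default `test*.py` pattern, running them all at once means overriding the discovery pattern. A minimal sketch:

```python
import unittest

# The test files are named *_test.py, so the default "test*.py"
# discovery pattern has to be overridden.
suite = unittest.TestLoader().discover(
    "normflows/distributions", pattern="*_test.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```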
