forked from VincentStimper/normalizing-flows
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
921a703
commit 014ac07
Showing
7 changed files
with
184 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
import unittest | ||
import torch | ||
|
||
from torch.testing import assert_close | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import unittest | ||
import torch | ||
import numpy as np | ||
|
||
from torch.testing import assert_close | ||
|
||
from normflows.nets import MLP | ||
from normflows.distributions.encoder import Dirac, Uniform, \ | ||
ConstDiagGaussian, NNDiagGaussian | ||
|
||
|
||
class EncoderTest(unittest.TestCase): | ||
|
||
def checkForward(self, encoder, inputs, num_samples=1): | ||
# Do forward | ||
outputs, log_p = encoder(inputs, num_samples) | ||
# Check type | ||
assert log_p.dtype == inputs.dtype | ||
assert outputs.dtype == inputs.dtype | ||
# Check shape | ||
assert log_p.shape[0] == inputs.shape[0] | ||
assert outputs.shape[0] == inputs.shape[0] | ||
assert log_p.shape[1] == num_samples | ||
assert outputs.shape[1] == num_samples | ||
# Check dim | ||
assert outputs.dim() > log_p.dim() | ||
# Return results | ||
return outputs, log_p | ||
|
||
def checkLogProb(self, encoder, inputs_z, inputs_x): | ||
# Compute log prob | ||
log_p = encoder.log_prob(inputs_z, inputs_x) | ||
# Check type | ||
assert log_p.dtype == inputs_z.dtype | ||
# Check shape | ||
assert log_p.shape[0] == inputs_z.shape[0] | ||
# Return results | ||
return log_p | ||
|
||
def checkForwardLogProb(self, encoder, inputs, num_samples=1, | ||
atol=None, rtol=None): | ||
# Check forward | ||
outputs, log_p = self.checkForward(encoder, inputs, | ||
num_samples) | ||
# Check log prob | ||
log_p_ = self.checkLogProb(encoder, outputs, inputs) | ||
# Check consistency | ||
assert_close(log_p_, log_p, atol=atol, rtol=rtol) | ||
|
||
def test_dirac_uniform(self): | ||
batch_size = 5 | ||
encoder_class = [Dirac, Uniform] | ||
params = [(2, 1), (1, 3), (2, 2)] | ||
|
||
# Test model | ||
for n_dim, num_samples in params: | ||
for encoder_c in encoder_class: | ||
with self.subTest(n_dim=n_dim, num_samples=num_samples, | ||
encoder_c=encoder_c): | ||
# Set up encoder | ||
encoder = encoder_c() | ||
# Do tests | ||
inputs = torch.rand(batch_size, n_dim) | ||
self.checkForwardLogProb(encoder, inputs, num_samples) | ||
|
||
def test_const_diag_gaussian(self): | ||
batch_size = 5 | ||
params = [(2, 1), (1, 3), (2, 2)] | ||
|
||
# Test model | ||
for n_dim, num_samples in params: | ||
with self.subTest(n_dim=n_dim, num_samples=num_samples): | ||
# Set up encoder | ||
loc = np.random.randn(n_dim).astype(np.float32) | ||
scale = np.random.rand(n_dim).astype(np.float32) * 0.1 + 1 | ||
encoder = ConstDiagGaussian(loc, scale) | ||
# Do tests | ||
inputs = torch.rand(batch_size, n_dim) | ||
self.checkForwardLogProb(encoder, inputs, num_samples) | ||
|
||
def test_nn_diag_gaussian(self): | ||
batch_size = 5 | ||
n_hidden = 8 | ||
params = [(4, 2, 1), (1, 1, 3), (5, 3, 2)] | ||
|
||
# Test model | ||
for n_dim, n_latent, num_samples in params: | ||
with self.subTest(n_dim=n_dim, n_latent=n_latent, | ||
num_samples=num_samples): | ||
# Set up encoder | ||
nn = MLP([n_dim, n_hidden, n_latent * 2]) | ||
encoder = NNDiagGaussian(nn) | ||
# Do tests | ||
inputs = torch.rand(batch_size, n_dim) | ||
self.checkForwardLogProb(encoder, inputs, num_samples) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import unittest | ||
import torch | ||
import numpy as np | ||
|
||
from normflows.distributions.prior import ImagePrior, TwoModes, \ | ||
Sinusoidal, Sinusoidal_split, Sinusoidal_gap, Smiley | ||
|
||
|
||
class PriorTest(unittest.TestCase): | ||
def test_2d_priors(self): | ||
batch_size = 5 | ||
prior_params = [(TwoModes, {'loc': 1, 'scale': 0.2}), | ||
(Sinusoidal, {'scale': 1., 'period': 1.}), | ||
(Sinusoidal_split, {'scale': 1., 'period': 1.}), | ||
(Sinusoidal_gap, {'scale': 1., 'period': 1.}), | ||
(Smiley, {'scale': 1.})] | ||
|
||
# Test model | ||
for prior_c, params in prior_params: | ||
with self.subTest(prior_c=prior_c, params=params): | ||
# Set up prior | ||
prior = prior_c(**params) | ||
# Test prior | ||
inputs = torch.randn(batch_size, 2) | ||
log_p = prior.log_prob(inputs) | ||
assert log_p.shape == (batch_size,) | ||
assert log_p.dtype == inputs.dtype | ||
|
||
def test_image_prior(self): | ||
# Set up prior | ||
image = np.random.rand(10, 10).astype(np.float32) | ||
prior = ImagePrior(image) | ||
for num_samples in [1, 5]: | ||
with self.subTest(num_samples=num_samples): | ||
# Test prior | ||
samples = prior.sample(num_samples) | ||
assert samples.shape == (num_samples, 2) | ||
assert samples.dtype == torch.float32 | ||
log_p = prior.log_prob(samples) | ||
assert log_p.shape == (num_samples,) | ||
assert log_p.dtype == torch.float32 | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import unittest | ||
import torch | ||
|
||
from normflows.distributions.target import TwoMoons, \ | ||
CircularGaussianMixture, RingMixture | ||
|
||
|
||
class TargetTest(unittest.TestCase): | ||
def test_targets(self): | ||
targets = [TwoMoons, CircularGaussianMixture, | ||
RingMixture] | ||
for num_samples in [1, 5]: | ||
for target_ in targets: | ||
with self.subTest(num_samples=num_samples, | ||
target_=target_): | ||
# Set up prior | ||
target = target_() | ||
# Test prior | ||
samples = target.sample(num_samples) | ||
assert samples.shape == (num_samples, 2) | ||
log_p = target.log_prob(samples) | ||
assert log_p.shape == (num_samples,) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |