Forked from VincentStimper/normalizing-flows
Commit 8ce961a: Added tests to two base distributions
1 parent: f02fce3

This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Showing 3 changed files with 100 additions and 0 deletions.
@@ -0,0 +1,31 @@

import unittest
import torch
import numpy as np

from normflows.distributions.base import DiagGaussian, UniformGaussian
from normflows.distributions.distribution_test import DistributionTest


class BaseTest(DistributionTest):
    def test_diag_gaussian(self):
        # Cover int, tuple, and list event shapes with different sample counts
        for shape in [1, (3,), [2, 3]]:
            for num_samples in [1, 3]:
                with self.subTest(shape=shape, num_samples=num_samples):
                    distribution = DiagGaussian(shape)
                    self.checkForwardLogProb(distribution, num_samples)
                    _ = self.checkSample(distribution, num_samples)

    def test_uniform_gaussian(self):
        # Each tuple is (ndim, indices of uniform entries, optional scale)
        params = [(2, 1, None), (2, (0,), 0.5 * torch.ones(2)),
                  (4, [2], None), (3, [2, 0], np.pi * torch.ones(3))]
        for ndim, ind, scale in params:
            for num_samples in [1, 3]:
                with self.subTest(ndim=ndim, ind=ind, scale=scale,
                                  num_samples=num_samples):
                    distribution = UniformGaussian(ndim, ind, scale)
                    self.checkForwardLogProb(distribution, num_samples)
                    _ = self.checkSample(distribution, num_samples)


if __name__ == "__main__":
    unittest.main()
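
The test above drives a shared checking routine over two concrete base distributions. The interface it exercises is the one the checks below rely on: calling a distribution module draws samples and returns their log-densities, log_prob re-evaluates densities for given inputs, and sample draws samples alone. A minimal sketch of that interface in use; the shapes in the comments are illustrative, not part of the commit:

import torch
from normflows.distributions.base import DiagGaussian

# Diagonal Gaussian over 3-dimensional vectors
distribution = DiagGaussian((3,))

# Forward pass: samples together with their log-densities
samples, log_p = distribution(5)   # samples: (5, 3), log_p: (5,)

# Log-density of given inputs
log_p_again = distribution.log_prob(samples)   # shape (5,)

# Samples alone, without log-densities
more_samples = distribution.sample(5)   # shape (5, 3)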
@@ -0,0 +1,53 @@

import unittest
import torch

from torch.testing import assert_close


class DistributionTest(unittest.TestCase):
    """
    Generic test case for distribution modules
    """
    def assertClose(self, actual, expected, atol=None, rtol=None):
        assert_close(actual, expected, atol=atol, rtol=rtol)

    def checkForward(self, distribution, num_samples=1, **kwargs):
        # Do forward
        outputs, log_p = distribution(num_samples, **kwargs)
        # Check type
        assert outputs.dtype == log_p.dtype
        # Check shape
        assert log_p.shape[0] == num_samples
        assert outputs.shape[0] == num_samples
        # Check dim
        assert outputs.dim() > log_p.dim()
        # Return results
        return outputs, log_p

    def checkLogProb(self, distribution, inputs, **kwargs):
        # Compute log prob
        log_p = distribution.log_prob(inputs, **kwargs)
        # Check type
        assert log_p.dtype == inputs.dtype
        # Check shape
        assert log_p.shape[0] == inputs.shape[0]
        # Return results
        return log_p

    def checkSample(self, distribution, num_samples=1):
        # Draw samples
        outputs = distribution.sample(num_samples)
        # Check shape
        assert outputs.shape[0] == num_samples
        # Check dim
        assert outputs.dim() > 1
        # Return results
        return outputs

    def checkForwardLogProb(self, distribution, num_samples=1, atol=None, rtol=None, **kwargs):
        # Check forward
        outputs, log_p = self.checkForward(distribution, num_samples, **kwargs)
        # Check log prob
        log_p_ = self.checkLogProb(distribution, outputs, **kwargs)
        # Check consistency
        self.assertClose(log_p_, log_p, atol, rtol)
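
DistributionTest is meant to be subclassed, as BaseTest above does: a concrete test case instantiates a distribution, calls checkForwardLogProb to verify that the log-density returned by the forward pass agrees with log_prob re-evaluated on the same samples, and checkSample for the sampling path. Note that checkForward and checkSample only assert structural properties (dtypes, batch size, ranks); the numerical agreement between the two log-density paths is pinned down by assertClose, a thin wrapper around torch.testing.assert_close. A sketch of the intended reuse pattern for a further distribution, with DiagGaussian standing in for a hypothetical new class (this template is not part of the commit):

import unittest

from normflows.distributions.base import DiagGaussian
from normflows.distributions.distribution_test import DistributionTest


class MyDistributionTest(DistributionTest):
    """Template for testing a further distribution; DiagGaussian
    stands in for the hypothetical new class."""

    def test_my_distribution(self):
        for num_samples in [1, 3]:
            with self.subTest(num_samples=num_samples):
                distribution = DiagGaussian(2)  # replace with the new distribution
                # Forward pass and log_prob must be consistent
                self.checkForwardLogProb(distribution, num_samples)
                # Sampling must return a batch of num_samples outputs
                _ = self.checkSample(distribution, num_samples)


if __name__ == "__main__":
    unittest.main()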