Skip to content

Commit

Permalink
Added tests to two base distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentStimper committed Jan 21, 2023
1 parent f02fce3 commit 8ce961a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
16 changes: 16 additions & 0 deletions normflows/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ def log_prob(self, z):
"""
raise NotImplementedError

def sample(self, num_samples=1):
"""Samples from base distribution
Args:
num_samples: Number of samples to draw from the distriubtion
Returns:
Samples drawn from the distribution
"""
z, _ = self.forward(num_samples)
return z


class DiagGaussian(BaseDistribution):
"""
Expand All @@ -51,6 +63,8 @@ def __init__(self, shape, trainable=True):
super().__init__()
if isinstance(shape, int):
shape = (shape,)
if isinstance(shape, list):
shape = tuple(shape)
self.shape = shape
self.n_dim = len(shape)
self.d = np.prod(shape)
Expand Down Expand Up @@ -104,6 +118,8 @@ def __init__(self, ndim, ind, scale=None):
"""
super().__init__()
self.ndim = ndim
if isinstance(ind, int):
ind = [ind]

# Set up indices and permutations
self.ndim = ndim
Expand Down
31 changes: 31 additions & 0 deletions normflows/distributions/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
import torch
import numpy as np

from normflows.distributions.base import DiagGaussian, UniformGaussian
from normflows.distributions.distribution_test import DistributionTest


class BaseTest(DistributionTest):
def test_diag_gaussian(self):
for shape in [1, (3,), [2, 3]]:
for num_samples in [1, 3]:
with self.subTest(shape=shape, num_samples=num_samples):
distribution = DiagGaussian(shape)
self.checkForwardLogProb(distribution, num_samples)
_ = self.checkSample(distribution, num_samples)

def test_uniform_gaussian(self):
params = [(2, 1, None), (2, (0,), 0.5 * torch.ones(2)),
(4, [2], None), (3, [2, 0], np.pi * torch.ones(3))]
for ndim, ind, scale in params:
for num_samples in [1, 3]:
with self.subTest(ndim=ndim, ind=ind, scale=scale,
num_samples=num_samples):
distribution = UniformGaussian(ndim, ind, scale)
self.checkForwardLogProb(distribution, num_samples)
_ = self.checkSample(distribution, num_samples)


if __name__ == "__main__":
unittest.main()
53 changes: 53 additions & 0 deletions normflows/distributions/distribution_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest
import torch

from torch.testing import assert_close


class DistributionTest(unittest.TestCase):
"""
Generic test case for distribution modules
"""
def assertClose(self, actual, expected, atol=None, rtol=None):
assert_close(actual, expected, atol=atol, rtol=rtol)

def checkForward(self, distribution, num_samples=1, **kwargs):
# Do forward
outputs, log_p = distribution(num_samples, **kwargs)
# Check type
assert outputs.dtype == log_p.dtype
# Check shape
assert log_p.shape[0] == num_samples
assert outputs.shape[0] == num_samples
# Check dim
assert outputs.dim() > log_p.dim()
# Return results
return outputs, log_p

def checkLogProb(self, distribution, inputs, **kwargs):
# Compute log prob
log_p = distribution.log_prob(inputs, **kwargs)
# Check type
assert log_p.dtype == inputs.dtype
# Check shape
assert log_p.shape[0] == inputs.shape[0]
# Return results
return log_p

def checkSample(self, distribution, num_samples=1):
# Do forward
outputs = distribution.sample(num_samples)
# Check shape
assert outputs.shape[0] == num_samples
# Check dim
assert outputs.dim() > 1
# Return results
return outputs

def checkForwardLogProb(self, distribution, num_samples=1, atol=None, rtol=None, **kwargs):
# Check forward
outputs, log_p = self.checkForward(distribution, num_samples, **kwargs)
# Check log prob
log_p_ = self.checkLogProb(distribution, outputs, **kwargs)
# Check consistency
self.assertClose(log_p_, log_p, atol, rtol)

0 comments on commit 8ce961a

Please sign in to comment.