Continuous Bernoulli distribution (take 2) (pytorch#34619)
Summary:
We recently published a NeurIPS paper (https://arxiv.org/abs/1907.06845 and https://papers.nips.cc/paper/9484-the-continuous-bernoulli-fixing-a-pervasive-error-in-variational-autoencoders) introducing a new [0,1]-supported distribution: the continuous Bernoulli. This pull request implements the distribution in PyTorch.
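
For reference, the density being implemented (an editor's summary of the paper, consistent with the reference computations in the tests below): the continuous Bernoulli is supported on [0, 1] with

    p(x \mid \lambda) = C(\lambda)\,\lambda^x (1-\lambda)^{1-x}, \qquad
    C(\lambda) = \frac{2\,\operatorname{artanh}(1-2\lambda)}{1-2\lambda} \ \ (\lambda \neq \tfrac{1}{2}), \qquad C(\tfrac{1}{2}) = 2,

i.e. an unnormalized Bernoulli likelihood made into a proper density by the normalizing constant C(\lambda). Since C(\lambda) is a 0/0 form at \lambda = 1/2, the code below switches to the Taylor expansion

    \log C(\lambda) \approx \log 2 + \tfrac{4}{3}\bigl(\lambda - \tfrac{1}{2}\bigr)^2 + \tfrac{104}{45}\bigl(\lambda - \tfrac{1}{2}\bigr)^4

on a small interval around 1/2.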
Pull Request resolved: pytorch#34619

Differential Revision: D20403123

Pulled By: ngimel

fbshipit-source-id: d807c7d0d372c6daf6cb6ef09df178bc7491abb2
gabloa authored and facebook-github-bot committed Mar 12, 2020
1 parent 944ea4c commit a74fbea
Showing 5 changed files with 454 additions and 3 deletions.
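
Before the diff itself, a minimal usage sketch of the new class (a hedged illustration, not taken from the commit; the probs/logits constructor arguments and the methods shown are exactly the ones exercised by the tests below):

    import torch
    from torch.distributions import ContinuousBernoulli

    # Shape parameter in (0, 1); a `logits` keyword is accepted instead of `probs`.
    d = ContinuousBernoulli(probs=torch.tensor([0.7, 0.2, 0.4]))
    x = d.sample((8,))    # draws in [0, 1], shape (8, 3)
    lp = d.log_prob(x)    # elementwise log density, shape (8, 3)
    h = d.entropy()       # differential entropy, shape (3,)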
9 changes: 9 additions & 0 deletions docs/source/distributions.rst
@@ -77,6 +77,15 @@ Probability distributions - torch.distributions
:undoc-members:
:show-inheritance:

:hidden:`ContinuousBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.continuous_bernoulli
.. autoclass:: ContinuousBernoulli
:members:
:undoc-members:
:show-inheritance:

:hidden:`Dirichlet`
~~~~~~~~~~~~~~~~~~~~~~~

189 changes: 186 additions & 3 deletions test/test_distributions.py
@@ -40,8 +40,8 @@
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.autograd import grad, gradcheck
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical,
Cauchy, Chi2, Dirichlet, Distribution,
Exponential, ExponentialFamily,
Cauchy, Chi2, ContinuousBernoulli, Dirichlet,
Distribution, Exponential, ExponentialFamily,
FisherSnedecor, Gamma, Geometric, Gumbel,
HalfCauchy, HalfNormal,
Independent, Laplace, LogisticNormal,
@@ -452,7 +452,13 @@ def is_all_nan(tensor):
{
'loc': torch.tensor([0.0, math.pi / 2], requires_grad=True),
'concentration': torch.tensor([1.0, 10.0], requires_grad=True)
}
},
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([0.7, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([0.3], requires_grad=True)},
{'probs': 0.3},
{'logits': torch.tensor([0.], requires_grad=True)},
])
]

@@ -673,6 +679,11 @@ def is_all_nan(tensor):
'scale': torch.tensor([1.0], requires_grad=True),
'concentration': torch.tensor([-1.0], requires_grad=True)
}
]),
Example(ContinuousBernoulli, [
{'probs': torch.tensor([1.1, 0.2, 0.4], requires_grad=True)},
{'probs': torch.tensor([-0.5], requires_grad=True)},
{'probs': 1.00001},
])
]

@@ -2402,6 +2413,44 @@ def test_beta_underflow_gpu(self):
self.assertEqual(frac_zeros, 0.5, 0.12)
self.assertEqual(frac_ones, 0.5, 0.12)

def test_continuous_bernoulli(self):
p = torch.tensor([0.7, 0.2, 0.4], requires_grad=True)
r = torch.tensor(0.3, requires_grad=True)
s = 0.3
self.assertEqual(ContinuousBernoulli(p).sample((8,)).size(), (8, 3))
self.assertFalse(ContinuousBernoulli(p).sample().requires_grad)
self.assertEqual(ContinuousBernoulli(r).sample((8,)).size(), (8,))
self.assertEqual(ContinuousBernoulli(r).sample().size(), ())
self.assertEqual(ContinuousBernoulli(r).sample((3, 2)).size(), (3, 2,))
self.assertEqual(ContinuousBernoulli(s).sample().size(), ())
self._gradcheck_log_prob(ContinuousBernoulli, (p,))

def ref_log_prob(idx, val, log_prob):
prob = p[idx]
if prob > 0.499 and prob < 0.501: # using default value of lim here
log_norm_const = math.log(2.) + 4. / 3. * math.pow(prob - 0.5, 2) + 104. / 45. * math.pow(prob - 0.5, 4)
else:
log_norm_const = math.log(2. * math.atanh(1. - 2. * prob) / (1. - 2.0 * prob))
res = val * math.log(prob) + (1. - val) * math.log1p(-prob) + log_norm_const
self.assertEqual(log_prob, res)

self._check_log_prob(ContinuousBernoulli(p), ref_log_prob)
self._check_log_prob(ContinuousBernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)

# check entropy computation
self.assertEqual(ContinuousBernoulli(p).entropy(), torch.tensor([-0.02938, -0.07641, -0.00682]), prec=1e-4)
# entropy below corresponds to the clamped value of prob when using float64
# the value for float32 should be -1.76898
self.assertEqual(ContinuousBernoulli(torch.tensor([0.0])).entropy(), torch.tensor([-2.58473]))
self.assertEqual(ContinuousBernoulli(s).entropy(), torch.tensor(-0.02938), prec=1e-4)
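# Editor's note, not part of the original diff: the reference entropies above
# follow from H(p) = -log C(p) - mu(p) * log(p) - (1 - mu(p)) * log1p(-p), with
# mean mu(p) = p / (2p - 1) + 1 / (2 * artanh(1 - 2p)) for p != 1/2 (and 1/2 at
# p = 1/2). For example, p = 0.7 gives mu ~= 0.5698 and H ~= -0.02938, matching
# the first entry of the expected tensor above.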

def test_continuous_bernoulli_3d(self):
p = torch.full((2, 3, 5), 0.5).requires_grad_()
self.assertEqual(ContinuousBernoulli(p).sample().size(), (2, 3, 5))
self.assertEqual(ContinuousBernoulli(p).sample(sample_shape=(2, 5)).size(),
(2, 5, 2, 3, 5))
self.assertEqual(ContinuousBernoulli(p).sample((2,)).size(), (2, 2, 3, 5))

def test_independent_shape(self):
for Dist, params in EXAMPLES:
for param in params:
@@ -3271,6 +3320,26 @@ def test_laplace_shape_tensor_params(self):
self.assertRaises(ValueError, laplace.log_prob, self.tensor_sample_2)
self.assertEqual(laplace.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

def test_continuous_bernoulli_shape_scalar_params(self):
continuous_bernoulli = ContinuousBernoulli(0.3)
self.assertEqual(continuous_bernoulli._batch_shape, torch.Size())
self.assertEqual(continuous_bernoulli._event_shape, torch.Size())
self.assertEqual(continuous_bernoulli.sample().size(), torch.Size())
self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.scalar_sample)
self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

def test_continuous_bernoulli_shape_tensor_params(self):
continuous_bernoulli = ContinuousBernoulli(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(continuous_bernoulli._batch_shape, torch.Size((3, 2)))
self.assertEqual(continuous_bernoulli._event_shape, torch.Size(()))
self.assertEqual(continuous_bernoulli.sample().size(), torch.Size((3, 2)))
self.assertEqual(continuous_bernoulli.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(continuous_bernoulli.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
self.assertRaises(ValueError, continuous_bernoulli.log_prob, self.tensor_sample_2)
self.assertEqual(continuous_bernoulli.log_prob(torch.ones(3, 1, 1)).size(), torch.Size((3, 3, 2)))


class TestKL(TestCase):

@@ -3316,6 +3385,7 @@ def __init__(self, probs):
uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])
continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])

# These tests should pass with precision = 0.01, but that makes tests very expensive.
# Instead, we test with precision = 0.1 and only test with higher precision locally
@@ -3374,6 +3444,10 @@ def __init__(self, probs):
(uniform_real, gumbel),
(uniform_real, normal),
(uniform_pareto, pareto),
(continuous_bernoulli, continuous_bernoulli),
(continuous_bernoulli, exponential),
(continuous_bernoulli, normal),
(beta, continuous_bernoulli)
]

self.infinite_examples = [
@@ -3429,6 +3503,18 @@ def __init__(self, probs):
(Uniform(-1, 2), Exponential(3)),
(Uniform(-1, 2), Gamma(3, 4)),
(Uniform(-1, 2), Pareto(3, 4)),
(ContinuousBernoulli(0.25), Uniform(0.25, 1)),
(ContinuousBernoulli(0.25), Uniform(0, 0.75)),
(ContinuousBernoulli(0.25), Uniform(0.25, 0.75)),
(ContinuousBernoulli(0.25), Pareto(1, 2)),
(Exponential(1), ContinuousBernoulli(0.75)),
(Gamma(1, 2), ContinuousBernoulli(0.75)),
(Gumbel(-1, 2), ContinuousBernoulli(0.75)),
(Laplace(-1, 2), ContinuousBernoulli(0.75)),
(Normal(-1, 2), ContinuousBernoulli(0.75)),
(Uniform(-1, 1), ContinuousBernoulli(0.75)),
(Uniform(0, 2), ContinuousBernoulli(0.75)),
(Uniform(-1, 2), ContinuousBernoulli(0.75))
]
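# Editor's note, not part of the original diff: both example lists are consumed
# via torch.distributions.kl.kl_divergence, which dispatches to registrations
# this commit presumably adds in torch/distributions/kl.py (one of the changed
# files not rendered on this page), e.g.
#     kl_divergence(ContinuousBernoulli(0.25), ContinuousBernoulli(0.75))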

def test_kl_monte_carlo(self):
@@ -3787,10 +3873,107 @@ def test_multinomial_log_prob_with_logits(self):
log_pdf_prob_0 = multinomial.log_prob(torch.tensor([10, 0], dtype=dtype))
self.assertEqual(log_pdf_prob_0.item(), -inf, allow_inf=True)

def test_continuous_bernoulli_gradient(self):

def expec_val(x, probs=None, logits=None):
assert not (probs is None and logits is None)
if logits is not None:
probs = 1. / (1. + math.exp(-logits))
bern_log_lik = x * math.log(probs) + (1. - x) * math.log1p(-probs)
if probs < 0.499 or probs > 0.501: # using default values of lims here
log_norm_const = math.log(
math.fabs(math.atanh(1. - 2. * probs))) - math.log(math.fabs(1. - 2. * probs)) + math.log(2.)
else:
aux = math.pow(probs - 0.5, 2)
log_norm_const = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * aux) * aux
log_lik = bern_log_lik + log_norm_const
return log_lik

def expec_grad(x, probs=None, logits=None):
assert not (probs is None and logits is None)
if logits is not None:
probs = 1. / (1. + math.exp(-logits))
grad_bern_log_lik = x / probs - (1. - x) / (1. - probs)
if probs < 0.499 or probs > 0.501: # using default values of lims here
grad_log_c = 2. * probs - 4. * (probs - 1.) * probs * math.atanh(1. - 2. * probs) - 1.
grad_log_c /= 2. * (probs - 1.) * probs * (2. * probs - 1.) * math.atanh(1. - 2. * probs)
else:
grad_log_c = 8. / 3. * (probs - 0.5) + 416. / 45. * math.pow(probs - 0.5, 3)
grad = grad_bern_log_lik + grad_log_c
if logits is not None:
grad *= 1. / (1. + math.exp(logits)) - 1. / math.pow(1. + math.exp(logits), 2)
return grad

for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([0.1]),
x=tensor_type([0.1]),
expected_value=tensor_type([expec_val(0.1, probs=0.1)]),
expected_gradient=tensor_type([expec_grad(0.1, probs=0.1)]))

self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([0.1]),
x=tensor_type([1.]),
expected_value=tensor_type([expec_val(1., probs=0.1)]),
expected_gradient=tensor_type([expec_grad(1., probs=0.1)]))

self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([0.4999]),
x=tensor_type([0.9]),
expected_value=tensor_type([expec_val(0.9, probs=0.4999)]),
expected_gradient=tensor_type([expec_grad(0.9, probs=0.4999)]))

self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([1e-4]),
x=tensor_type([1]),
expected_value=tensor_type([expec_val(1, probs=1e-4)]),
expected_gradient=tensor_type([expec_grad(1, probs=1e-4)]),
prec=1e-3)

self._test_pdf_score(dist_class=ContinuousBernoulli,
probs=tensor_type([1 - 1e-4]),
x=tensor_type([0.1]),
expected_value=tensor_type([expec_val(0.1, probs=1 - 1e-4)]),
expected_gradient=tensor_type([expec_grad(0.1, probs=1 - 1e-4)]),
prec=2)

self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([math.log(9999)]),
x=tensor_type([0]),
expected_value=tensor_type([expec_val(0, logits=math.log(9999))]),
expected_gradient=tensor_type([expec_grad(0, logits=math.log(9999))]),
prec=1e-3)

self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([0.001]),
x=tensor_type([0.5]),
expected_value=tensor_type([expec_val(0.5, logits=0.001)]),
expected_gradient=tensor_type([expec_grad(0.5, logits=0.001)]))
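
# Editor's note, not part of the original diff: expec_grad differentiates the
# log normalizer. For p away from 1/2,
#     d/dp log C(p) = 2 / (1 - 2p) - 1 / (2 * p * (1 - p) * artanh(1 - 2p)),
# which equals grad_log_c above after putting both terms over a common
# denominator; near p = 1/2, differentiating the Taylor expansion of log C
# gives (8/3) * (p - 1/2) + (416/45) * (p - 1/2)**3, the else branch.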

def test_continuous_bernoulli_with_logits_underflow(self):
for tensor_type, lim, expected in ([(torch.FloatTensor, -1e38, 2.76898),
(torch.DoubleTensor, -1e308, 3.58473)]):
self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([lim]),
x=tensor_type([0]),
expected_value=tensor_type([expected]),
expected_gradient=tensor_type([0.]))

def test_continuous_bernoulli_with_logits_overflow(self):
for tensor_type, lim, expected in ([(torch.FloatTensor, 1e38, 2.76898),
(torch.DoubleTensor, 1e308, 3.58473)]):
self._test_pdf_score(dist_class=ContinuousBernoulli,
logits=tensor_type([lim]),
x=tensor_type([1]),
expected_value=tensor_type([expected]),
expected_gradient=tensor_type([0.]))
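
# Editor's note, not part of the original diff: the expected values reflect
# probs saturating at machine epsilon after clamping. With artanh(1 - 2*eps)
# ~= 0.5 * log(1/eps), the log density at the boundary is approximately
# log C(eps) ~= log(log(1/eps)): about 2.76898 for float32 (eps ~= 1.19e-7)
# and 3.58473 for float64 (eps ~= 2.22e-16). The expected gradient is 0
# because the clamp is saturated at these extreme logits.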


class TestLazyLogitsInitialization(TestCase):
def setUp(self):
super(TestLazyLogitsInitialization, self).setUp()
# ContinuousBernoulli is not tested because log_prob is not computed simply
# from 'logits', but 'probs' is also needed
self.examples = [e for e in EXAMPLES if e.Dist in
(Categorical, OneHotCategorical, Bernoulli, Binomial, Multinomial)]

2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -78,6 +78,7 @@
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
@@ -118,6 +119,7 @@
'Categorical',
'Cauchy',
'Chi2',
'ContinuousBernoulli',
'Dirichlet',
'Distribution',
'Exponential',
(The remaining two changed files, the new torch/distributions/continuous_bernoulli.py and, presumably, the KL-divergence registrations in torch/distributions/kl.py, are not rendered on this page.)
