Enable distribution validation if __debug__ (pytorch#48743)
Summary:
Fixes pytorch#47123
Follows pyro-ppl/pyro#2701

This turns on `Distribution` validation by default. The motivation is to help beginners by surfacing clear error messages early. Advanced users who care about speed can disable validation globally by calling
```py
torch.distributions.Distribution.set_default_validate_args(False)
```
or per distribution instance via `MyDistribution(..., validate_args=False)`.

In practice I have found that many beginners either forget about validation or do not know it exists, so I have [enabled it by default](pyro-ppl/pyro#2701) in Pyro; I believe PyTorch would benefit from the same change. Indeed, validation caught a number of bugs in `.icdf()` methods, in tests, and in PPL benchmarks, all of which are fixed in this PR.
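
For illustration, a minimal sketch of the resulting default behavior (the parameter values are made up):
```py
import torch
import torch.distributions as dist

# Validation is now on by default, so malformed parameters fail loudly:
try:
    dist.Bernoulli(probs=torch.tensor(2.0))  # probs must lie in [0, 1]
except ValueError as err:
    print(err)

# Opt out for a single distribution ...
dist.Bernoulli(probs=torch.tensor(2.0), validate_args=False)

# ... or globally, e.g. once a model is known to be correct:
dist.Distribution.set_default_validate_args(False)
```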

## Release concerns
- This may slightly slow down some models; users concerned about speed can disable validation as described above.
- This may cause new `ValueError`s in models that rely on unsupported behavior, e.g. `Bernoulli.log_prob()` applied to continuous-valued tensors (only {0,1}-valued tensors are supported).

We should clearly note this change in release notes.

Pull Request resolved: pytorch#48743

Reviewed By: heitorschueroff

Differential Revision: D25304247

Pulled By: neerajprad

fbshipit-source-id: 8d50f28441321ae691f848c55f71aa80cb356b41
fritzo authored and facebook-github-bot committed Jan 5, 2021
1 parent e3c56dd commit 093aca0
Showing 12 changed files with 69 additions and 39 deletions.
7 changes: 4 additions & 3 deletions benchmarks/functional_autograd_benchmark/ppl_models.py
@@ -24,8 +24,9 @@ def forward(beta_value: Tensor) -> Tensor:
mu = X.mm(beta_value)

# We need to compute the first and second gradient of this score with respect
# to beta_value.
score = dist.Bernoulli(logits=mu).log_prob(Y).sum() + beta_prior.log_prob(beta_value).sum()
# to beta_value. We disable Bernoulli validation because Y is a relaxed value.
score = (dist.Bernoulli(logits=mu, validate_args=False).log_prob(Y).sum() +
beta_prior.log_prob(beta_value).sum())
return score

return forward, (beta_value.to(device),)
@@ -40,7 +41,7 @@ def get_robust_regression(device: torch.device) -> GetterReturnType:
Y = torch.rand(N, 1, device=device)

# Predefined nu_alpha and nu_beta, nu_alpha.shape: (1, 1), nu_beta.shape: (1, 1)
nu_alpha = torch.randn(1, 1, device=device)
nu_alpha = torch.rand(1, 1, device=device)
nu_beta = torch.rand(1, 1, device=device)
nu = dist.Gamma(nu_alpha, nu_beta)

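
For context, a short sketch (with illustrative shapes) of why the Bernoulli in the first hunk needs `validate_args=False` for its relaxed `Y`:
```py
import torch
import torch.distributions as dist

Y = torch.rand(4, 1)   # "relaxed" labels: continuous values in [0, 1]
b = dist.Bernoulli(logits=torch.zeros(4, 1), validate_args=False)
b.log_prob(Y)          # fine: validation is disabled for this instance
# With validation enabled this would raise ValueError, because Bernoulli's
# support only admits {0, 1}-valued tensors.
```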
43 changes: 22 additions & 21 deletions test/distributions/test_distributions.py
@@ -727,7 +727,7 @@ def _gradcheck_log_prob(self, dist_ctor, ctor_params):
# performs gradient checks on log_prob
distribution = dist_ctor(*ctor_params)
s = distribution.sample()
if s.is_floating_point():
if not distribution.support.is_discrete:
s = s.detach().requires_grad_()

expected_shape = distribution.batch_shape + distribution.event_shape
@@ -1422,7 +1422,7 @@ def test_uniform(self):
self.assertEqual(Uniform(0.0, 1.0).sample((1,)).size(), (1,))

# Check log_prob computation when value outside range
uniform = Uniform(low_1d, high_1d)
uniform = Uniform(low_1d, high_1d, validate_args=False)
above_high = torch.tensor([4.0])
below_low = torch.tensor([-1.0])
self.assertEqual(uniform.log_prob(above_high).item(), -inf)
@@ -1517,7 +1517,7 @@ def test_halfcauchy(self):

def test_halfnormal(self):
std = torch.randn(5, 5).abs().requires_grad_()
std_1d = torch.randn(1, requires_grad=True)
std_1d = torch.randn(1).abs().requires_grad_()
std_delta = torch.tensor([1e-5, 1e-5])
self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
@@ -1978,6 +1978,8 @@ def gradcheck_func(samples, mu, sigma, prec, scale_tril):
sigma = 0.5 * (sigma + sigma.transpose(-1, -2)) # Ensure symmetry of covariance
if prec is not None:
prec = 0.5 * (prec + prec.transpose(-1, -2)) # Ensure symmetry of precision
if scale_tril is not None:
scale_tril = scale_tril.tril()
return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)
gradcheck(gradcheck_func, (mvn_samples, mean, covariance, precision, scale_tril), raise_exception=True)

@@ -2643,7 +2645,7 @@ def test_cdf_log_prob(self):
for i, param in enumerate(params):
dist = Dist(**param)
samples = dist.sample()
if samples.dtype.is_floating_point:
if not dist.support.is_discrete:
samples.requires_grad_()
try:
cdfs = dist.cdf(samples)
@@ -3050,11 +3052,9 @@ def setUp(self):
self.scalar_sample = 1
self.tensor_sample_1 = torch.ones(3, 2)
self.tensor_sample_2 = torch.ones(3, 2, 3)
Distribution.set_default_validate_args(True)

def tearDown(self):
super(TestDistributionShapes, self).tearDown()
Distribution.set_default_validate_args(False)

def test_entropy_shape(self):
for Dist, params in EXAMPLES:
@@ -3186,23 +3186,23 @@ def test_one_hot_categorical_shape(self):
self.assertEqual(dist.sample().size(), torch.Size((3,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
simplex_sample = self.tensor_sample_2 / self.tensor_sample_2.sum(-1, keepdim=True)
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 2,)))
sample = torch.tensor([0., 1., 0.]).expand(3, 2, 3)
self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 2,)))
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((3,)))
simplex_sample = torch.ones(3, 3) / 3
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
sample = torch.eye(3)
self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
# batched
dist = OneHotCategorical(torch.tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size((2,)))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
simplex_sample = self.tensor_sample_1 / self.tensor_sample_1.sum(-1, keepdim=True)
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3,)))
sample = torch.tensor([0., 1.])
self.assertEqual(dist.log_prob(sample).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
simplex_sample = torch.ones(3, 1, 2) / 2
self.assertEqual(dist.log_prob(simplex_sample).size(), torch.Size((3, 3)))
sample = torch.tensor([0., 1.]).expand(3, 1, 2)
self.assertEqual(dist.log_prob(sample).size(), torch.Size((3, 3)))

def test_cauchy_shape_scalar_params(self):
cauchy = Cauchy(0, 1)
@@ -3531,12 +3531,15 @@ def __init__(self, probs):
[0.2, 0.7, 0.1],
[0.33, 0.33, 0.34],
[0.2, 0.2, 0.6]])
pareto = pairwise(Pareto, [2.5, 4.0, 2.5, 4.0], [2.25, 3.75, 2.25, 3.75])
pareto = (Pareto(torch.tensor([2.5, 4.0, 2.5, 4.0]).expand(4, 4),
torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4)),
Pareto(torch.tensor([2.25, 3.75, 2.25, 3.8]).expand(4, 4),
torch.tensor([2.25, 3.75, 2.25, 3.75]).expand(4, 4)))
poisson = pairwise(Poisson, [0.3, 1.0, 5.0, 10.0])
uniform_within_unit = pairwise(Uniform, [0.15, 0.95, 0.2, 0.8], [0.1, 0.9, 0.25, 0.75])
uniform_within_unit = pairwise(Uniform, [0.1, 0.9, 0.2, 0.75], [0.15, 0.95, 0.25, 0.8])
uniform_positive = pairwise(Uniform, [1, 1.5, 2, 4], [1.2, 2.0, 3, 7])
uniform_real = pairwise(Uniform, [-2., -1, 0, 2], [-1., 1, 1, 4])
uniform_pareto = pairwise(Uniform, [6.5, 8.5, 6.5, 8.5], [7.5, 7.5, 9.5, 9.5])
uniform_pareto = pairwise(Uniform, [6.5, 7.5, 6.5, 8.5], [7.5, 8.5, 9.5, 9.5])
continuous_bernoulli = pairwise(ContinuousBernoulli, [0.1, 0.2, 0.5, 0.9])

# These tests should pass with precision = 0.01, but that makes tests very expensive.
@@ -4148,8 +4151,8 @@ def test_lazy_logits_initialization(self):
probs = param.pop('probs')
param['logits'] = probs_to_logits(probs)
dist = Dist(**param)
shape = (1,) if not dist.event_shape else dist.event_shape
dist.log_prob(torch.ones(shape))
# Create new instance to generate a valid sample
dist.log_prob(Dist(**param).sample())
message = 'Failed for {} example 0/{}'.format(Dist.__name__, len(params))
self.assertFalse('probs' in vars(dist), msg=message)
try:
@@ -4455,7 +4458,6 @@ def test_stack_transform(self):
class TestValidation(TestCase):
def setUp(self):
super(TestCase, self).setUp()
Distribution.set_default_validate_args(True)

def test_valid(self):
for Dist, params in EXAMPLES:
@@ -4475,7 +4477,6 @@ def test_invalid(self):

def tearDown(self):
super(TestValidation, self).tearDown()
Distribution.set_default_validate_args(False)


class TestJit(TestCase):
2 changes: 0 additions & 2 deletions torch/distributions/cauchy.py
@@ -69,8 +69,6 @@ def cdf(self, value):
return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5

def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc

def entropy(self):
27 changes: 27 additions & 0 deletions torch/distributions/constraints.py
@@ -3,13 +3,17 @@
- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
@@ -57,6 +61,8 @@ class Constraint(object):
A constraint object represents a region over which a variable is valid,
e.g. within which a variable can be optimized.
"""
is_discrete = False

def check(self, value):
"""
Returns a byte tensor of `sample_shape + batch_shape` indicating
@@ -103,14 +109,30 @@ class _Boolean(Constraint):
"""
Constrain to the two values `{0, 1}`.
"""
is_discrete = True

def check(self, value):
return (value == 0) | (value == 1)


class _OneHot(Constraint):
"""
Constrain to one-hot vectors.
"""
is_discrete = True

def check(self, value):
is_boolean = (value == 0) | (value == 1)
is_normalized = value.sum(-1).eq(1)
return is_boolean.all(-1) & is_normalized


class _IntegerInterval(Constraint):
"""
Constrain to an integer interval `[lower_bound, upper_bound]`.
"""
is_discrete = True

def __init__(self, lower_bound, upper_bound):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
@@ -128,6 +150,8 @@ class _IntegerLessThan(Constraint):
"""
Constrain to an integer interval `(-inf, upper_bound]`.
"""
is_discrete = True

def __init__(self, upper_bound):
self.upper_bound = upper_bound

@@ -144,6 +168,8 @@ class _IntegerGreaterThan(Constraint):
"""
Constrain to an integer interval `[lower_bound, inf)`.
"""
is_discrete = True

def __init__(self, lower_bound):
self.lower_bound = lower_bound

@@ -358,6 +384,7 @@ def check(self, value):
dependent = _Dependent()
dependent_property = _DependentProperty
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
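
For illustration, a quick sketch of how the new `one_hot` constraint and `is_discrete` flag behave (assuming the definitions above):
```py
import torch
from torch.distributions import constraints, Bernoulli

# one_hot accepts values that are {0, 1}-valued and sum to 1 along the last dim.
constraints.one_hot.check(torch.eye(3))          # tensor([True, True, True])
constraints.one_hot.check(torch.ones(3, 3) / 3)  # tensor([False, False, False])

# is_discrete lets callers branch on the support instead of the sample dtype,
# as the updated tests do via `not dist.support.is_discrete`.
print(Bernoulli(0.3).support.is_discrete)        # True (boolean constraint)
print(constraints.real.is_discrete)              # False (the Constraint default)
```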
2 changes: 0 additions & 2 deletions torch/distributions/continuous_bernoulli.py
@@ -168,8 +168,6 @@ def cdf(self, value):
torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs))

def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
cut_probs = self._cut_probs()
return torch.where(
self._outside_unstable_region(),
13 changes: 12 additions & 1 deletion torch/distributions/distribution.py
@@ -12,10 +12,21 @@ class Distribution(object):

has_rsample = False
has_enumerate_support = False
_validate_args = False
_validate_args = __debug__

@staticmethod
def set_default_validate_args(value):
"""
Sets whether validation is enabled or disabled.
The default behavior mimics Python's ``assert`` statement: validation
is on by default, but is disabled if Python is run in optimized mode
(via ``python -O``). Validation may be expensive, so you may want to
disable it once a model is working.
Args:
value (bool): Whether to enable validation.
"""
if value not in [True, False]:
raise ValueError
Distribution._validate_args = value
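
A small sketch of the `__debug__` semantics described in the docstring (behavior depends on how the interpreter is invoked):
```py
# python script.py     -> __debug__ is True  -> validation on by default
# python -O script.py  -> __debug__ is False -> validation off by default
from torch.distributions import Distribution, Normal

Normal(0., 1.)                                  # validate_args=None falls back to the class default
Distribution.set_default_validate_args(False)   # explicit global opt-out at any time
```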
2 changes: 0 additions & 2 deletions torch/distributions/exponential.py
@@ -68,8 +68,6 @@ def cdf(self, value):
return 1 - torch.exp(-self.rate * value)

def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
return -torch.log(1 - value) / self.rate

def entropy(self):
2 changes: 0 additions & 2 deletions torch/distributions/laplace.py
@@ -75,8 +75,6 @@ def cdf(self, value):
return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1(-(value - self.loc).abs() / self.scale)

def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
term = value - 0.5
return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs())

4 changes: 3 additions & 1 deletion torch/distributions/negative_binomial.py
@@ -77,8 +77,10 @@ def param_shape(self):

@lazy_property
def _gamma(self):
# Note we avoid validating because self.total_count can be zero.
return torch.distributions.Gamma(concentration=self.total_count,
rate=torch.exp(-self.logits))
rate=torch.exp(-self.logits),
validate_args=False)

def sample(self, sample_shape=torch.Size()):
with torch.no_grad():
2 changes: 0 additions & 2 deletions torch/distributions/normal.py
@@ -82,8 +82,6 @@ def cdf(self, value):
return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))

def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)

def entropy(self):
2 changes: 1 addition & 1 deletion torch/distributions/one_hot_categorical.py
@@ -29,7 +29,7 @@ class OneHotCategorical(Distribution):
"""
arg_constraints = {'probs': constraints.simplex,
'logits': constraints.real}
support = constraints.simplex
support = constraints.one_hot
has_enumerate_support = True

def __init__(self, probs=None, logits=None, validate_args=None):
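
With the support changed from `simplex` to `one_hot`, a validated `log_prob()` now expects one-hot values, matching the test updates above. A minimal sketch:
```py
import torch
from torch.distributions import OneHotCategorical

d = OneHotCategorical(torch.tensor([0.6, 0.3, 0.1]))
d.log_prob(torch.tensor([0., 1., 0.]))   # ok: the value is one-hot
# d.log_prob(torch.ones(3) / 3)          # would now raise ValueError under validation,
#                                        # since a bare simplex point is not one-hot
```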
2 changes: 0 additions & 2 deletions torch/distributions/uniform.py
@@ -81,8 +81,6 @@ def cdf(self, value):
return result.clamp(min=0, max=1)

def icdf(self, value):
if self._validate_args:
self._validate_sample(value)
result = value * (self.high - self.low) + self.low
return result

