Skip to content

Commit

Permalink
Backwards-compatible updates to TF Distributions interface; mostly
Browse files Browse the repository at this point in the history
surfacing subclass-private-specialization docstrings.

* subclass specializations now have signatures in Distributions class
  (self-describing interface)
* subclass docstrings of private methods are now properly appended
  to the public methods (i.e., Normal.log_pdf.doc += Normal._log_pdf.doc)
* modified places that return AttributeError to return the right type of error
  (either NotImplemented, TypeError, or ValueError)
Change: 134550708
  • Loading branch information
ebrevdo authored and tensorflower-gardener committed Sep 28, 2016
1 parent 09ecbd9 commit 3cb3439
Show file tree
Hide file tree
Showing 52 changed files with 2,266 additions and 723 deletions.
8 changes: 1 addition & 7 deletions tensorflow/contrib/distributions/python/ops/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,10 @@ def _std(self):
return math_ops.sqrt(self._variance())

def _mode(self):
"""Returns `1` if `p > 1-p` and `0` otherwise."""
return math_ops.cast(self.p > self.q, self.dtype)


distribution_util.append_class_fun_doc(Bernoulli.mode, doc_str="""
Specific notes:
1 if p > 1-p. 0 otherwise.
""")


class BernoulliWithSigmoidP(Bernoulli):
"""Bernoulli with `p = sigmoid(p)`."""

Expand Down
35 changes: 15 additions & 20 deletions tensorflow/contrib/distributions/python/ops/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
from tensorflow.python.ops import random_ops


_beta_prob_note = """
Note that the argument `x` must be a non-negative floating point tensor
whose shape can be broadcast with `self.a` and `self.b`. For fixed leading
dimensions, the last dimension represents counts for the corresponding Beta
distribution in `self.a` and `self.b`. `x` is only legal if `0 < x < 1`.
"""


class Beta(distribution.Distribution):
"""Beta distribution.
Expand Down Expand Up @@ -202,9 +210,11 @@ def _log_prob(self, x):
math_ops.lgamma(self.a_b_sum))
return log_unnormalized_prob - log_normalization

@distribution_util.AppendDocstring(_beta_prob_note)
def _prob(self, x):
return math_ops.exp(self._log_prob(x))

@distribution_util.AppendDocstring(_beta_prob_note)
def _log_cdf(self, x):
return math_ops.log(self._cdf(x))

Expand All @@ -228,6 +238,11 @@ def _variance(self):
def _std(self):
return math_ops.sqrt(self.variance())

@distribution_util.AppendDocstring(
"""Note that the mode for the Beta distribution is only defined
when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`,
and `NaN` otherwise. If `self.allow_nan_stats` is `False`, an exception
will be raised rather than returning `NaN`.""")
def _mode(self):
mode = (self.a - 1.)/ (self.a_b_sum - 2.)
if self.allow_nan_stats:
Expand Down Expand Up @@ -261,26 +276,6 @@ def _assert_valid_sample(self, x):
], x)


_prob_note = """
Note that the argument `x` must be a non-negative floating point tensor
whose shape can be broadcast with `self.a` and `self.b`. For fixed leading
dimensions, the last dimension represents counts for the corresponding Beta
distribution in `self.a` and `self.b`. `x` is only legal if `0 < x < 1`.
"""

distribution_util.append_class_fun_doc(Beta.log_prob, doc_str=_prob_note)
distribution_util.append_class_fun_doc(Beta.prob, doc_str=_prob_note)

distribution_util.append_class_fun_doc(Beta.mode, doc_str="""
Note that the mode for the Beta distribution is only defined
when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`,
and `NaN` otherwise. If `self.allow_nan_stats` is `False`, an exception
will be raised rather than returning `NaN`.
""")


class BetaWithSoftplusAB(Beta):
"""Beta with softplus transform on `a` and `b`."""

Expand Down
44 changes: 20 additions & 24 deletions tensorflow/contrib/distributions/python/ops/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

_binomial_prob_note = """
For each batch member of counts `value`, `P[counts]` is the probability that
after sampling `n` draws from this Binomial distribution, the number of
successes is `k`. Note that different sequences of draws can result in the
same counts, thus the probability includes a combinatorial coefficient.
`value` must be a non-negative tensor with dtype `dtype` and whose shape
can be broadcast with `self.p` and `self.n`. `counts` is only legal if it is
less than or equal to `n` and its components are equal to integer
values.
"""


class Binomial(distribution.Distribution):
"""Binomial distribution.
Expand Down Expand Up @@ -172,6 +184,7 @@ def _event_shape(self):
def _get_event_shape(self):
return tensor_shape.scalar()

@distribution_util.AppendDocstring(_binomial_prob_note)
def _log_prob(self, counts):
counts = self._check_counts(counts)
prob_prob = (counts * math_ops.log(self.p) +
Expand All @@ -182,6 +195,7 @@ def _log_prob(self, counts):
log_prob = prob_prob + combinations
return log_prob

@distribution_util.AppendDocstring(_binomial_prob_note)
def _prob(self, counts):
return math_ops.exp(self._log_prob(counts))

Expand All @@ -194,11 +208,16 @@ def _variance(self):
def _std(self):
return math_ops.sqrt(self._variance())

@distribution_util.AppendDocstring(
"""Note that when `(n + 1) * p` is an integer, there are actually two
modes. Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here
we return only the larger of the two modes.""")
def _mode(self):
return math_ops.floor((self._n + 1) * self._p)

@distribution_util.AppendDocstring(
"""Check counts for proper shape, values, then return tensor version.""")
def _check_counts(self, counts):
"""Check counts for proper shape, values, then return tensor version."""
counts = ops.convert_to_tensor(counts, name="counts_before_deps")
if not self.validate_args:
return counts
Expand All @@ -209,26 +228,3 @@ def _check_counts(self, counts):
counts, self._n, message="counts are not less than or equal to n."),
distribution_util.assert_integer_form(
counts, message="counts have non-integer components.")], counts)


_prob_note = """
For each batch member of counts `k`, `P[counts]` is the probability that
after sampling `n` draws from this Binomial distribution, the number of
successes is `k`. Note that different sequences of draws can result in the
same counts, thus the probability includes a combinatorial coefficient.
counts: Non-negative tensor with dtype `dtype` and whose shape can be
broadcast with `self.p` and `self.n`. `counts` is only legal if it is
less than or equal to `n` and its components are equal to integer
values.
"""
distribution_util.append_class_fun_doc(Binomial.log_prob, doc_str=_prob_note)
distribution_util.append_class_fun_doc(Binomial.prob, doc_str=_prob_note)

distribution_util.append_class_fun_doc(Binomial.mode, doc_str="""
Note that when `(n + 1) * p` is an integer, there are actually two modes.
Namely, `(n + 1) * p` and `(n + 1) * p - 1` are both modes. Here we return
only the larger of the two modes.
""")
34 changes: 15 additions & 19 deletions tensorflow/contrib/distributions/python/ops/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
from tensorflow.python.ops import special_math_ops


_dirichlet_prob_note = """
Note that the input must be a non-negative tensor with dtype `dtype` and whose
shape can be broadcast with `self.alpha`. For fixed leading dimensions, the
last dimension represents counts for the corresponding Dirichlet distribution
in `self.alpha`. `x` is only legal if it sums up to one.
"""


class Dirichlet(distribution.Distribution):
"""Dirichlet distribution.
Expand Down Expand Up @@ -175,6 +183,7 @@ def _sample_n(self, n, seed=None):
return gamma_sample / math_ops.reduce_sum(
gamma_sample, reduction_indices=[-1], keep_dims=True)

@distribution_util.AppendDocstring(_dirichlet_prob_note)
def _log_prob(self, x):
x = ops.convert_to_tensor(x, name="x")
x = self._assert_valid_sample(x)
Expand All @@ -184,6 +193,7 @@ def _log_prob(self, x):
keep_dims=False) - special_math_ops.lbeta(self.alpha)
return log_prob

@distribution_util.AppendDocstring(_dirichlet_prob_note)
def _prob(self, x):
return math_ops.exp(self._log_prob(x))

Expand Down Expand Up @@ -212,6 +222,11 @@ def _variance(self):
def _std(self):
return math_ops.sqrt(self._variance())

@distribution_util.AppendDocstring(
"""Note that the mode for the Dirichlet distribution is only defined
when `alpha > 1`. This returns the mode when `alpha > 1`,
and NaN otherwise. If `self.allow_nan_stats` is `False`, an exception
will be raised rather than returning `NaN`.""")
def _mode(self):
mode = ((self.alpha - 1.) /
(array_ops.expand_dims(self.alpha_sum, dim=-1) -
Expand All @@ -238,22 +253,3 @@ def _assert_valid_sample(self, x):
array_ops.ones((), dtype=self.dtype),
math_ops.reduce_sum(x, reduction_indices=[-1])),
], x)


_prob_note = """
Note that the input must be a non-negative tensor with dtype `dtype` and whose
shape can be broadcast with `self.alpha`. For fixed leading dimensions, the
last dimension represents counts for the corresponding Dirichlet distribution
in `self.alpha`. `x` is only legal if it sums up to one.
"""
distribution_util.append_class_fun_doc(Dirichlet.log_prob, doc_str=_prob_note)
distribution_util.append_class_fun_doc(Dirichlet.prob, doc_str=_prob_note)

distribution_util.append_class_fun_doc(Dirichlet.mode, doc_str="""
Note that the mode for the Dirichlet distribution is only defined
when `alpha > 1`. This returns the mode when `alpha > 1`,
and NaN otherwise. If `self.allow_nan_stats` is `False`, an exception
will be raised rather than returning `NaN`.
""")
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@
from tensorflow.python.ops import special_math_ops


_dirichlet_multinomial_prob_note = """
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
that after sampling `n` draws from this Dirichlet Multinomial
distribution, the number of draws falling in class `j` is `n_j`. Note that
different sequences of draws can result in the same counts, thus the
probability includes a combinatorial coefficient.
Note that input, "counts", must be a non-negative tensor with dtype `dtype`
and whose shape can be broadcast with `self.alpha`. For fixed leading
dimensions, the last dimension represents counts for the corresponding
Dirichlet Multinomial distribution in `self.alpha`. `counts` is only legal if
it sums up to `n` and its components are equal to integer values.
"""


class DirichletMultinomial(distribution.Distribution):
"""DirichletMultinomial mixture distribution.
Expand Down Expand Up @@ -192,6 +207,7 @@ def _get_event_shape(self):
# Event shape depends only on alpha, not "n".
return self.alpha.get_shape().with_rank_at_least(1)[-1:]

@distribution_util.AppendDocstring(_dirichlet_multinomial_prob_note)
def _log_prob(self, counts):
counts = self._assert_valid_counts(counts)
ordered_prob = (special_math_ops.lbeta(self.alpha + counts) -
Expand All @@ -200,13 +216,31 @@ def _log_prob(self, counts):
self.n, counts)
return log_prob

@distribution_util.AppendDocstring(_dirichlet_multinomial_prob_note)
def _prob(self, counts):
return math_ops.exp(self._log_prob(counts))

def _mean(self):
normalized_alpha = self.alpha / array_ops.expand_dims(self.alpha_sum, -1)
return array_ops.expand_dims(self.n, -1) * normalized_alpha

@distribution_util.AppendDocstring(
"""The variance for each batch member is defined as the following:
```
Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
(n + alpha_0) / (1 + alpha_0)
```
where `alpha_0 = sum_j alpha_j`.
The covariance between elements in a batch is defined as:
```
Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
(n + alpha_0) / (1 + alpha_0)
```
""")
def _variance(self):
alpha_sum = array_ops.expand_dims(self.alpha_sum, -1)
normalized_alpha = self.alpha / alpha_sum
Expand Down Expand Up @@ -248,44 +282,3 @@ def _assert_valid_n(self, n, validate_args):
return control_flow_ops.with_dependencies(
[check_ops.assert_non_negative(n),
distribution_util.assert_integer_form(n)], n)


_prob_note = """
For each batch of counts `[n_1,...,n_k]`, `P[counts]` is the probability
that after sampling `n` draws from this Dirichlet Multinomial
distribution, the number of draws falling in class `j` is `n_j`. Note that
different sequences of draws can result in the same counts, thus the
probability includes a combinatorial coefficient.
Note that input, "counts", must be a non-negative tensor with dtype `dtype`
and whose shape can be broadcast with `self.alpha`. For fixed leading
dimensions, the last dimension represents counts for the corresponding
Dirichlet Multinomial distribution in `self.alpha`. `counts` is only legal if
it sums up to `n` and its components are equal to integer values.
"""
distribution_util.append_class_fun_doc(DirichletMultinomial.log_prob,
doc_str=_prob_note)
distribution_util.append_class_fun_doc(DirichletMultinomial.prob,
doc_str=_prob_note)

distribution_util.append_class_fun_doc(DirichletMultinomial.variance,
doc_str="""
The variance for each batch member is defined as the following:
```
Var(X_j) = n * alpha_j / alpha_0 * (1 - alpha_j / alpha_0) *
(n + alpha_0) / (1 + alpha_0)
```
where `alpha_0 = sum_j alpha_j`.
The covariance between elements in a batch is defined as:
```
Cov(X_i, X_j) = -n * alpha_i * alpha_j / alpha_0 ** 2 *
(n + alpha_0) / (1 + alpha_0)
```
""")
Loading

0 comments on commit 3cb3439

Please sign in to comment.