From f4a0c2c0f1bbff6e9b4d5d4a0796e7645a974321 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Thu, 16 Mar 2017 07:15:30 -0800 Subject: [PATCH] Deterministic and VectorDeterministic distributions added. Change: 150320911 --- tensorflow/contrib/distributions/BUILD | 17 + tensorflow/contrib/distributions/__init__.py | 3 + .../python/kernel_tests/deterministic_test.py | 295 ++++++++++++++ .../distributions/python/ops/deterministic.py | 383 ++++++++++++++++++ 4 files changed, 698 insertions(+) create mode 100644 tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py create mode 100644 tensorflow/contrib/distributions/python/ops/deterministic.py diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 7e6f05e38ca6bd..9cf15d3c6f9a2d 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -270,6 +270,23 @@ cuda_py_test( ], ) +cuda_py_test( + name = "deterministic_test", + size = "small", + srcs = ["python/kernel_tests/deterministic_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "dirichlet_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index c575ba97a59994..627512c8e8adb8 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -29,6 +29,8 @@ @@Categorical @@Chi2 @@Chi2WithAbsDf +@@Deterministic +@@VectorDeterministic @@Exponential @@ExponentialWithSoftplusRate @@Gamma @@ -94,6 +96,7 @@ from tensorflow.contrib.distributions.python.ops.chi2 import * from 
tensorflow.contrib.distributions.python.ops.conditional_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * +from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.dirichlet import * from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import * from tensorflow.contrib.distributions.python.ops.distribution import * diff --git a/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py new file mode 100644 index 00000000000000..90910f3839b1a4 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/deterministic_test.py @@ -0,0 +1,295 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""Tests for the Deterministic and VectorDeterministic distributions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test

# Fixed seed so the randomly generated test data is reproducible.
rng = np.random.RandomState(0)


class DeterministicTest(test.TestCase):
  """Tests for the scalar `Deterministic` distribution."""

  def testShape(self):
    with self.test_session():
      loc = rng.rand(2, 3, 4)
      deterministic = deterministic_lib.Deterministic(loc)

      # Scalar distribution: every dim of loc is a batch dim; event is scalar.
      self.assertAllEqual(deterministic.batch_shape_tensor().eval(), (2, 3, 4))
      self.assertAllEqual(deterministic.batch_shape, (2, 3, 4))
      self.assertAllEqual(deterministic.event_shape_tensor().eval(), [])
      self.assertEqual(deterministic.event_shape, tensor_shape.TensorShape([]))

  def testInvalidTolRaises(self):
    # With validate_args=True, a negative atol must trip the
    # assert_non_negative check when the graph is evaluated.
    loc = rng.rand(2, 3, 4).astype(np.float32)
    deterministic = deterministic_lib.Deterministic(
        loc, atol=-1, validate_args=True)
    with self.test_session():
      with self.assertRaisesOpError("Condition x >= 0"):
        deterministic.prob(0.).eval()

  def testProbWithNoBatchDimsIntegerType(self):
    # Integer loc: prob is exact point-mass membership (no float slack).
    deterministic = deterministic_lib.Deterministic(0)
    with self.test_session():
      self.assertAllClose(1, deterministic.prob(0).eval())
      self.assertAllClose(0, deterministic.prob(2).eval())
      self.assertAllClose([1, 0], deterministic.prob([0, 2]).eval())

  def testProbWithNoBatchDims(self):
    deterministic = deterministic_lib.Deterministic(0.)
    with self.test_session():
      self.assertAllClose(1., deterministic.prob(0.).eval())
      self.assertAllClose(0., deterministic.prob(2.).eval())
      self.assertAllClose([1., 0.], deterministic.prob([0., 2.]).eval())

  def testProbWithDefaultTol(self):
    # Default atol/rtol are 0, so only exact matches get prob 1.
    loc = [[0., 1.], [2., 3.]]
    x = [[0., 1.1], [1.99, 3.]]
    deterministic = deterministic_lib.Deterministic(loc)
    expected_prob = [[1., 0.], [0., 1.]]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((2, 2), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbWithNonzeroATol(self):
    # |1.99 - 2.| = 0.01 <= atol, so that entry now has prob 1.
    loc = [[0., 1.], [2., 3.]]
    x = [[0., 1.1], [1.99, 3.]]
    deterministic = deterministic_lib.Deterministic(loc, atol=0.05)
    expected_prob = [[1., 0.], [1., 1.]]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((2, 2), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbWithNonzeroATolIntegerType(self):
    loc = [[0, 1], [2, 3]]
    x = [[0, 2], [4, 2]]
    deterministic = deterministic_lib.Deterministic(loc, atol=1)
    expected_prob = [[1, 1], [0, 1]]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((2, 2), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbWithNonzeroRTol(self):
    # Slack scales with |loc|: 100 * 0.01 = 1, so 100.1 is close, 103 is not.
    loc = [[0., 1.], [100., 100.]]
    x = [[0., 1.1], [100.1, 103.]]
    deterministic = deterministic_lib.Deterministic(loc, rtol=0.01)
    expected_prob = [[1., 0.], [1., 0.]]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((2, 2), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbWithNonzeroRTolIntegerType(self):
    loc = [[10, 10, 10], [10, 10, 10]]
    x = [[10, 20, 30], [10, 20, 30]]
    # Batch 0 will have rtol = 0
    # Batch 1 will have rtol = 1 (100% slack allowed)
    deterministic = deterministic_lib.Deterministic(loc, rtol=[[0], [1]])
    expected_prob = [[1, 0, 0], [1, 1, 0]]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((2, 3), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testCdfWithDefaultTol(self):
    # cdf(x) is 1 iff x >= loc (within zero slack here).
    loc = [[0., 0.], [0., 0.]]
    x = [[-1., -0.1], [-0.01, 1.000001]]
    deterministic = deterministic_lib.Deterministic(loc)
    expected_cdf = [[0., 0.], [0., 1.]]
    with self.test_session():
      cdf = deterministic.cdf(x)
      self.assertAllEqual((2, 2), cdf.get_shape())
      self.assertAllEqual(expected_cdf, cdf.eval())

  def testCdfWithNonzeroATol(self):
    loc = [[0., 0.], [0., 0.]]
    x = [[-1., -0.1], [-0.01, 1.000001]]
    deterministic = deterministic_lib.Deterministic(loc, atol=0.05)
    expected_cdf = [[0., 0.], [1., 1.]]
    with self.test_session():
      cdf = deterministic.cdf(x)
      self.assertAllEqual((2, 2), cdf.get_shape())
      self.assertAllEqual(expected_cdf, cdf.eval())

  def testCdfWithNonzeroRTol(self):
    loc = [[1., 1.], [100., 100.]]
    x = [[0.9, 1.], [99.9, 97]]
    deterministic = deterministic_lib.Deterministic(loc, rtol=0.01)
    expected_cdf = [[0., 1.], [1., 0.]]
    with self.test_session():
      cdf = deterministic.cdf(x)
      self.assertAllEqual((2, 2), cdf.get_shape())
      self.assertAllEqual(expected_cdf, cdf.eval())

  def testSampleNoBatchDims(self):
    # Samples are constant: every draw equals loc.
    deterministic = deterministic_lib.Deterministic(0.)
    for sample_shape in [(), (4,)]:
      with self.test_session():
        sample = deterministic.sample(sample_shape)
        self.assertAllEqual(sample_shape, sample.get_shape())
        self.assertAllClose(
            np.zeros(sample_shape).astype(np.float32), sample.eval())

  def testSampleWithBatchDims(self):
    deterministic = deterministic_lib.Deterministic([0., 0.])
    for sample_shape in [(), (4,)]:
      with self.test_session():
        sample = deterministic.sample(sample_shape)
        self.assertAllEqual(sample_shape + (2,), sample.get_shape())
        self.assertAllClose(
            np.zeros(sample_shape + (2,)).astype(np.float32), sample.eval())

  def testSampleDynamicWithBatchDims(self):
    # Shapes only known at session run time (fed through placeholders).
    loc = array_ops.placeholder(np.float32)
    sample_shape = array_ops.placeholder(np.int32)

    deterministic = deterministic_lib.Deterministic(loc)
    for sample_shape_ in [(), (4,)]:
      with self.test_session():
        sample_ = deterministic.sample(sample_shape).eval(
            feed_dict={loc: [0., 0.],
                       sample_shape: sample_shape_})
        self.assertAllClose(
            np.zeros(sample_shape_ + (2,)).astype(np.float32), sample_)


class VectorDeterministicTest(test.TestCase):
  """Tests for the `VectorDeterministic` distribution on R^k."""

  def testShape(self):
    with self.test_session():
      loc = rng.rand(2, 3, 4)
      deterministic = deterministic_lib.VectorDeterministic(loc)

      # Vector distribution: the last dim of loc is the event dim.
      self.assertAllEqual(deterministic.batch_shape_tensor().eval(), (2, 3))
      self.assertAllEqual(deterministic.batch_shape, (2, 3))
      self.assertAllEqual(deterministic.event_shape_tensor().eval(), [4])
      self.assertEqual(deterministic.event_shape, tensor_shape.TensorShape([4]))

  def testInvalidTolRaises(self):
    loc = rng.rand(2, 3, 4).astype(np.float32)
    deterministic = deterministic_lib.VectorDeterministic(
        loc, atol=-1, validate_args=True)
    with self.test_session():
      with self.assertRaisesOpError("Condition x >= 0"):
        deterministic.prob(loc).eval()

  def testInvalidXRaises(self):
    # prob of a scalar x is rejected statically: x must be a vector in R^k.
    loc = rng.rand(2, 3, 4).astype(np.float32)
    deterministic = deterministic_lib.VectorDeterministic(
        loc, atol=-1, validate_args=True)
    with self.test_session():
      with self.assertRaisesRegexp(ValueError, "must have rank at least 1"):
        deterministic.prob(0.).eval()

  def testProbVectorDeterministicWithNoBatchDims(self):
    # 0 batch of deterministics on R^1.
    deterministic = deterministic_lib.VectorDeterministic([0.])
    with self.test_session():
      self.assertAllClose(1., deterministic.prob([0.]).eval())
      self.assertAllClose(0., deterministic.prob([2.]).eval())
      self.assertAllClose([1., 0.], deterministic.prob([[0.], [2.]]).eval())

  def testProbWithDefaultTol(self):
    # 3 batch of deterministics on R^2.
    loc = [[0., 1.], [2., 3.], [4., 5.]]
    x = [[0., 1.], [1.9, 3.], [3.99, 5.]]
    deterministic = deterministic_lib.VectorDeterministic(loc)
    expected_prob = [1., 0., 0.]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((3,), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbWithNonzeroATol(self):
    # 3 batch of deterministics on R^2.
    loc = [[0., 1.], [2., 3.], [4., 5.]]
    x = [[0., 1.], [1.9, 3.], [3.99, 5.]]
    deterministic = deterministic_lib.VectorDeterministic(loc, atol=0.05)
    expected_prob = [1., 0., 1.]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((3,), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbWithNonzeroRTol(self):
    # 3 batch of deterministics on R^2.
    loc = [[0., 1.], [1., 1.], [100., 100.]]
    x = [[0., 1.], [0.9, 1.], [99.9, 100.1]]
    deterministic = deterministic_lib.VectorDeterministic(loc, rtol=0.01)
    expected_prob = [1., 0., 1.]
    with self.test_session():
      prob = deterministic.prob(x)
      self.assertAllEqual((3,), prob.get_shape())
      self.assertAllEqual(expected_prob, prob.eval())

  def testProbVectorDeterministicWithNoBatchDimsOnRZero(self):
    # 0 batch of deterministics on R^0.
    # R^0 contains exactly one point, the empty vector, so prob([]) == 1.
    deterministic = deterministic_lib.VectorDeterministic(
        [], validate_args=True)
    with self.test_session():
      self.assertAllClose(1., deterministic.prob([]).eval())

  def testProbVectorDeterministicWithNoBatchDimsOnRZeroRaisesIfXNotInSameRk(
      self):
    # 0 batch of deterministics on R^0.
    deterministic = deterministic_lib.VectorDeterministic(
        [], validate_args=True)
    with self.test_session():
      with self.assertRaisesOpError("not defined in the same space"):
        deterministic.prob([1.]).eval()

  def testSampleNoBatchDims(self):
    deterministic = deterministic_lib.VectorDeterministic([0.])
    for sample_shape in [(), (4,)]:
      with self.test_session():
        sample = deterministic.sample(sample_shape)
        self.assertAllEqual(sample_shape + (1,), sample.get_shape())
        self.assertAllClose(
            np.zeros(sample_shape + (1,)).astype(np.float32), sample.eval())

  def testSampleWithBatchDims(self):
    deterministic = deterministic_lib.VectorDeterministic([[0.], [0.]])
    for sample_shape in [(), (4,)]:
      with self.test_session():
        sample = deterministic.sample(sample_shape)
        self.assertAllEqual(sample_shape + (2, 1), sample.get_shape())
        self.assertAllClose(
            np.zeros(sample_shape + (2, 1)).astype(np.float32), sample.eval())

  def testSampleDynamicWithBatchDims(self):
    # Shapes only known at session run time (fed through placeholders).
    loc = array_ops.placeholder(np.float32)
    sample_shape = array_ops.placeholder(np.int32)

    deterministic = deterministic_lib.VectorDeterministic(loc)
    for sample_shape_ in [(), (4,)]:
      with self.test_session():
        sample_ = deterministic.sample(sample_shape).eval(
            feed_dict={loc: [[0.], [0.]],
                       sample_shape: sample_shape_})
        self.assertAllClose(
            np.zeros(sample_shape_ + (2, 1)).astype(np.float32), sample_)


if __name__ == "__main__":
  test.main()
b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -0,0 +1,383 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Deterministic distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.distributions.python.ops import distribution +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops + +__all__ = [ + "Deterministic", + "VectorDeterministic", +] + + +@six.add_metaclass(abc.ABCMeta) +class _BaseDeterministic(distribution.Distribution): + """Base class for Deterministic distributions.""" + + def __init__(self, + loc, + atol=None, + rtol=None, + is_vector=False, + validate_args=False, + allow_nan_stats=True, + name="_BaseDeterministic"): + """Initialize a batch of `_BaseDeterministic` distributions. + + The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf` + computations, e.g. 
due to floating-point error. + + ``` + pmf(x; loc) + = 1, if Abs(x - loc) <= atol + rtol * Abs(loc), + = 0, otherwise. + ``` + + Args: + loc: Numeric `Tensor`. The point (or batch of points) on which this + distribution is supported. + atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable + shape. The absolute tolerance for comparing closeness to `loc`. + Default is `0`. + rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable + shape. The relative tolerance for comparing closeness to `loc`. + Default is `0`. + is_vector: Python `bool`. If `True`, this is for `VectorDeterministic`, + else `Deterministic`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + + Raises: + ValueError: If `loc` is a scalar. + """ + parameters = locals() + with ops.name_scope(name, values=[loc, atol, rtol]): + loc = ops.convert_to_tensor(loc, name="loc") + if is_vector and validate_args: + msg = "Argument loc must be at least rank 1." 
+ if loc.get_shape().ndims is not None: + if loc.get_shape().ndims < 1: + raise ValueError(msg) + else: + loc = control_flow_ops.with_dependencies( + [check_ops.assert_rank_at_least(loc, 1, message=msg)], loc) + self._loc = loc + + super(_BaseDeterministic, self).__init__( + dtype=self._loc.dtype, + reparameterization_type=distribution.NOT_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._loc], + name=name) + + self._atol = self._get_tol(atol) + self._rtol = self._get_tol(rtol) + # Avoid using the large broadcast with self.loc if possible. + if rtol is None: + self._slack = self.atol + else: + self._slack = self.atol + self.rtol * math_ops.abs(self.loc) + + def _get_tol(self, tol): + if tol is None: + return ops.convert_to_tensor(0, dtype=self.loc.dtype) + + tol = ops.convert_to_tensor(tol, dtype=self.loc.dtype) + if self.validate_args: + tol = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + tol, message="Argument 'tol' must be non-negative") + ], tol) + return tol + + @property + def loc(self): + """Point (or batch of points) at which this distribution is supported.""" + return self._loc + + @property + def atol(self): + """Absolute tolerance for comparing points to `self.loc`.""" + return self._atol + + @property + def rtol(self): + """Relative tolerance for comparing points to `self.loc`.""" + return self._rtol + + def _mean(self): + return array_ops.identity(self.loc) + + def _variance(self): + return array_ops.zeros_like(self.loc) + + def _mode(self): + return self.mean() + + def _sample_n(self, n, seed=None): # pylint: disable=unused-arg + n_static = tensor_util.constant_value(ops.convert_to_tensor(n)) + if n_static is not None and self.loc.get_shape().ndims is not None: + ones = [1] * self.loc.get_shape().ndims + multiples = [n_static] + ones + else: + ones = array_ops.ones_like(array_ops.shape(self.loc)) + multiples = array_ops.concat(([n], ones), 
axis=0) + + return array_ops.tile(self.loc[array_ops.newaxis, ...], multiples=multiples) + + +class Deterministic(_BaseDeterministic): + """Scalar `Deterministic` distribution on the real line. + + The scalar `Deterministic` distribution is parameterized by a [batch] point + `loc` on the real line. The distribution is supported at this point only, + and corresponds to a random variable that is constant, equal to `loc`. + + See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution). + + #### Mathematical Details + + The probability mass function (pmf) and cumulative distribution function (cdf) + are + + ```none + pmf(x; loc) = 1, if x == loc, else 0 + cdf(x; loc) = 1, if x >= loc, else 0 + ``` + + #### Examples + + ```python + # Initialize a single Deterministic supported at zero. + constant = tf.contrib.distributions.Deterministic(0.) + constant.prob(0.) + ==> 1. + constant.prob(2.) + ==> 0. + + # Initialize a [2, 2] batch of scalar constants. + loc = [[0., 1.], [2., 3.]] + x = [[0., 1.1], [1.99, 3.]] + constant = tf.contrib.distributions.Deterministic(loc) + constant.prob(x) + ==> [[1., 0.], [0., 1.]] + ``` + + """ + + def __init__(self, + loc, + atol=None, + rtol=None, + validate_args=False, + allow_nan_stats=True, + name="Deterministic"): + """Initialize a scalar `Deterministic` distribution. + + The `atol` and `rtol` parameters allow for some slack in `pmf`, `cdf` + computations, e.g. due to floating-point error. + + ``` + pmf(x; loc) + = 1, if Abs(x - loc) <= atol + rtol * Abs(loc), + = 0, otherwise. + ``` + + Args: + loc: Numeric `Tensor` of shape `[B1, ..., Bb]`, with `b >= 0`. + The point (or batch of points) on which this distribution is supported. + atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable + shape. The absolute tolerance for comparing closeness to `loc`. + Default is `0`. + rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable + shape. The relative tolerance for comparing closeness to `loc`. 
+ Default is `0`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + super(Deterministic, self).__init__( + loc, + atol=atol, + rtol=rtol, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=name) + + def _batch_shape_tensor(self): + return array_ops.shape(self.loc) + + def _batch_shape(self): + return self.loc.get_shape() + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _prob(self, x): + return math_ops.cast( + math_ops.abs(x - self.loc) <= self._slack, dtype=self.dtype) + + def _cdf(self, x): + return math_ops.cast(x >= self.loc - self._slack, dtype=self.dtype) + + +class VectorDeterministic(_BaseDeterministic): + """Vector `Deterministic` distribution on `R^k`. + + The `VectorDeterministic` distribution is parameterized by a [batch] point + `loc in R^k`. The distribution is supported at this point only, + and corresponds to a random variable that is constant, equal to `loc`. + + See [Degenerate rv](https://en.wikipedia.org/wiki/Degenerate_distribution). + + #### Mathematical Details + + The probability mass function (pmf) is + + ```none + pmf(x; loc) + = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)], + = 0, otherwise. + ``` + + #### Examples + + ```python + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. 
+ constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant.prob([0., 2.]) + ==> 1. + constant.prob([0., 3.]) + ==> 0. + + # Initialize a [3] batch of constants on R^2. + loc = [[0., 1.], [2., 3.], [4., 5.]] + constant = constant_lib.VectorDeterministic(loc) + constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) + ==> [1., 0., 0.] + ``` + + """ + + def __init__(self, + loc, + atol=None, + rtol=None, + validate_args=False, + allow_nan_stats=True, + name="VectorDeterministic"): + """Initialize a `VectorDeterministic` distribution on `R^k`, for `k >= 0`. + + Note that there is only one point in `R^0`, the "point" `[]`. So if `k = 0` + then `self.prob([]) == 1`. + + The `atol` and `rtol` parameters allow for some slack in `pmf` + computations, e.g. due to floating-point error. + + ``` + pmf(x; loc) + = 1, if All[Abs(x - loc) <= atol + rtol * Abs(loc)], + = 0, otherwise + ``` + + Args: + loc: Numeric `Tensor` of shape `[B1, ..., Bb, k]`, with `b >= 0`, `k >= 0` + The point (or batch of points) on which this distribution is supported. + atol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable + shape. The absolute tolerance for comparing closeness to `loc`. + Default is `0`. + rtol: Non-negative `Tensor` of same `dtype` as `loc` and broadcastable + shape. The relative tolerance for comparing closeness to `loc`. + Default is `0`. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. 
+ """ + super(VectorDeterministic, self).__init__( + loc, + atol=atol, + rtol=rtol, + is_vector=True, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=name) + + def _batch_shape_tensor(self): + return array_ops.shape(self.loc)[:-1] + + def _batch_shape(self): + return self.loc.get_shape()[:-1] + + def _event_shape_tensor(self): + return array_ops.shape(self.loc)[-1] + + def _event_shape(self): + return self.loc.get_shape()[-1:] + + def _prob(self, x): + if self.validate_args: + is_vector_check = check_ops.assert_rank_at_least(x, 1) + right_vec_space_check = check_ops.assert_equal( + self.event_shape_tensor(), + array_ops.gather(array_ops.shape(x), array_ops.rank(x) - 1), + message= + "Argument 'x' not defined in the same space R^k as this distribution") + with ops.control_dependencies([is_vector_check]): + with ops.control_dependencies([right_vec_space_check]): + x = array_ops.identity(x) + return math_ops.cast( + math_ops.reduce_all(math_ops.abs(x - self.loc) <= self._slack, axis=-1), + dtype=self.dtype)