diff --git a/edward/__init__.py b/edward/__init__.py
index c28673559..70b3d049a 100644
--- a/edward/__init__.py
+++ b/edward/__init__.py
@@ -21,6 +21,6 @@
     RandomVariable
 from edward.util import copy, dot, get_ancestors, get_children, \
     get_descendants, get_dims, get_parents, get_session, get_siblings, \
-    get_variables, hessian, log_sum_exp, logit, multivariate_rbf, \
-    placeholder, random_variables, rbf, set_seed, to_simplex
+    get_variables, hessian, logit, multivariate_rbf, placeholder, \
+    random_variables, rbf, reduce_logmeanexp, set_seed, to_simplex
 from edward.version import __version__
diff --git a/edward/inferences/klpq.py b/edward/inferences/klpq.py
index 11a11c63b..76a928581 100644
--- a/edward/inferences/klpq.py
+++ b/edward/inferences/klpq.py
@@ -7,7 +7,7 @@
 
 from edward.inferences.variational_inference import VariationalInference
 from edward.models import RandomVariable, Normal
-from edward.util import copy, log_sum_exp
+from edward.util import copy
 
 
 class KLpq(VariationalInference):
@@ -134,7 +134,7 @@ def build_loss_and_gradients(self, var_list):
     q_log_prob = tf.stack(q_log_prob)
 
     log_w = p_log_prob - q_log_prob
-    log_w_norm = log_w - log_sum_exp(log_w)
+    log_w_norm = log_w - tf.reduce_logsumexp(log_w)
     w_norm = tf.exp(log_w_norm)
 
     if var_list is None:
diff --git a/edward/util/tensorflow.py b/edward/util/tensorflow.py
index 350bd1e14..42dd699a1 100644
--- a/edward/util/tensorflow.py
+++ b/edward/util/tensorflow.py
@@ -110,57 +110,6 @@ def hessian(y, xs):
   return tf.stack(mat)
 
 
-def log_mean_exp(input_tensor, axis=None, keep_dims=False):
-  """Computes log(mean(exp(elements across dimensions of a tensor))).
-
-  Parameters
-  ----------
-  input_tensor : tf.Tensor
-    The tensor to reduce. Should have numeric type.
-  axis : int or list of int, optional
-    The dimensions to reduce. If `None` (the default), reduces all
-    dimensions.
-  keep_dims : bool, optional
-    If true, retains reduced dimensions with length 1.
-
-  Returns
-  -------
-  tf.Tensor
-    The reduced tensor.
-  """
-  logsumexp = tf.reduce_logsumexp(input_tensor, axis, keep_dims)
-  input_tensor = tf.convert_to_tensor(input_tensor)
-  n = input_tensor.get_shape().as_list()
-  if axis is None:
-    n = tf.cast(tf.reduce_prod(n), logsumexp.dtype)
-  else:
-    n = tf.cast(tf.reduce_prod(n[axis]), logsumexp.dtype)
-
-  return -tf.log(n) + logsumexp
-
-
-def log_sum_exp(input_tensor, axis=None, keep_dims=False, name=None):
-  """Compute the ``log_sum_exp`` of elements in a tensor, taking
-  the sum across axes given by ``axis``.
-
-  Parameters
-  ----------
-  input_tensor : tf.Tensor
-    The tensor to reduce. Should have numeric type.
-  axis : int or list of int, optional
-    The dimensions to reduce. If `None` (the default), reduces all
-    dimensions.
-  keep_dims : bool, optional
-    If true, retains reduced dimensions with length 1.
-
-  Returns
-  -------
-  tf.Tensor
-    The reduced tensor.
-  """
-  return tf.reduce_logsumexp(input_tensor, axis, keep_dims, name)
-
-
 def logit(x):
   """Evaluate :math:`\log(x / (1 - x))` elementwise.
 
@@ -290,6 +239,35 @@ def rbf(x, y=0.0, sigma=1.0, l=1.0):
       tf.exp(-1.0 / (2.0 * tf.pow(l, 2.0)) * tf.pow(x - y, 2.0))
 
 
+def reduce_logmeanexp(input_tensor, axis=None, keep_dims=False):
+  """Computes log(mean(exp(elements across dimensions of a tensor))).
+
+  Parameters
+  ----------
+  input_tensor : tf.Tensor
+    The tensor to reduce. Should have numeric type.
+  axis : int or list of int, optional
+    The dimensions to reduce. If `None` (the default), reduces all
+    dimensions.
+  keep_dims : bool, optional
+    If true, retains reduced dimensions with length 1.
+
+  Returns
+  -------
+  tf.Tensor
+    The reduced tensor.
+  """
+  logsumexp = tf.reduce_logsumexp(input_tensor, axis, keep_dims)
+  input_tensor = tf.convert_to_tensor(input_tensor)
+  n = input_tensor.get_shape().as_list()
+  if axis is None:
+    n = tf.cast(tf.reduce_prod(n), logsumexp.dtype)
+  else:
+    n = tf.cast(tf.reduce_prod(n[axis]), logsumexp.dtype)
+
+  return -tf.log(n) + logsumexp
+
+
 def to_simplex(x):
   """Transform real vector of length ``(K-1)`` to a simplex of
   dimension ``K`` using a backward stick breaking construction.
diff --git a/examples/iwvi.py b/examples/iwvi.py
index 011acb80f..29b7c4bae 100644
--- a/examples/iwvi.py
+++ b/examples/iwvi.py
@@ -14,7 +14,7 @@
 
 from edward.inferences import VariationalInference
 from edward.models import Bernoulli, Normal, RandomVariable
-from edward.util import copy, log_mean_exp
+from edward.util import copy, reduce_logmeanexp
 from scipy.special import expit
 
 
@@ -78,7 +78,7 @@ def build_loss_and_gradients(self, var_list):
 
       log_w += [p_log_prob - q_log_prob]
 
-    loss = -log_mean_exp(log_w)
+    loss = -reduce_logmeanexp(log_w)
     grads = tf.gradients(loss, [v._ref() for v in var_list])
     grads_and_vars = list(zip(grads, var_list))
     return loss, grads_and_vars
diff --git a/examples/tf_mixture_gaussian.py b/examples/tf_mixture_gaussian.py
index b13fc0bae..13ad39088 100644
--- a/examples/tf_mixture_gaussian.py
+++ b/examples/tf_mixture_gaussian.py
@@ -15,7 +15,7 @@
 
 from edward.models import Dirichlet, Normal, InverseGamma
 from edward.stats import dirichlet, invgamma, multivariate_normal_diag, norm
-from edward.util import get_dims, log_sum_exp
+from edward.util import get_dims
 
 plt.style.use('ggplot')
 
@@ -68,9 +68,9 @@ def log_prob(self, xs, zs):
                        sigmas[(k * self.D):((k + 1) * self.D)])]
 
     matrix = tf.stack(matrix)
-    # log_sum_exp() along the rows is a vector, whose nth
+    # log sum exp along the rows is a vector, whose nth
     # element is the log-likelihood of data point x_n.
-    vector = log_sum_exp(matrix, 0)
+    vector = tf.reduce_logsumexp(matrix, 0)
     # Sum over data points to get the full log-likelihood.
     log_lik = tf.reduce_sum(vector)
 
diff --git a/examples/tf_mixture_gaussian_laplace.py b/examples/tf_mixture_gaussian_laplace.py
index 2d8ecf726..65e6cd2a0 100644
--- a/examples/tf_mixture_gaussian_laplace.py
+++ b/examples/tf_mixture_gaussian_laplace.py
@@ -13,7 +13,7 @@
 
 from edward.models import PointMass
 from edward.stats import dirichlet, invgamma, multivariate_normal_diag, norm
-from edward.util import get_dims, log_sum_exp
+from edward.util import get_dims
 
 
 class MixtureGaussian:
@@ -64,9 +64,9 @@ def log_prob(self, xs, zs):
                        sigmas[(k * self.D):((k + 1) * self.D)])]
 
     matrix = tf.stack(matrix)
-    # log_sum_exp() along the rows is a vector, whose nth
+    # log sum exp along the rows is a vector, whose nth
    # element is the log-likelihood of data point x_n.
-    vector = log_sum_exp(matrix, 0)
+    vector = tf.reduce_logsumexp(matrix, 0)
     # Sum over data points to get the full log-likelihood.
     log_lik = tf.reduce_sum(vector)
 
diff --git a/examples/tf_mixture_gaussian_map.py b/examples/tf_mixture_gaussian_map.py
index 2602a1f2e..c7efbf309 100644
--- a/examples/tf_mixture_gaussian_map.py
+++ b/examples/tf_mixture_gaussian_map.py
@@ -13,7 +13,7 @@
 
 from edward.models import PointMass
 from edward.stats import dirichlet, invgamma, multivariate_normal_diag, norm
-from edward.util import get_dims, log_sum_exp
+from edward.util import get_dims
 
 
 class MixtureGaussian:
@@ -64,9 +64,9 @@ def log_prob(self, xs, zs):
                        sigmas[(k * self.D):((k + 1) * self.D)])]
 
     matrix = tf.stack(matrix)
-    # log_sum_exp() along the rows is a vector, whose nth
+    # log sum exp along the rows is a vector, whose nth
     # element is the log-likelihood of data point x_n.
-    vector = log_sum_exp(matrix, 0)
+    vector = tf.reduce_logsumexp(matrix, 0)
     # Sum over data points to get the full log-likelihood.
     log_lik = tf.reduce_sum(vector)
 
diff --git a/tests/test-util/test_log_sum_exp.py b/tests/test-util/test_log_sum_exp.py
deleted file mode 100644
index 0694bf7f0..000000000
--- a/tests/test-util/test_log_sum_exp.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from edward.util import log_sum_exp
-
-
-class test_log_sum_exp_class(tf.test.TestCase):
-
-  def test_log_sum_exp_1d(self):
-    with self.test_session():
-      x = tf.constant([-1.0, -2.0, -3.0, -4.0])
-      self.assertAllClose(log_sum_exp(x).eval(),
-                          -0.5598103014388045)
-
-  def test_log_sum_exp_2d(self):
-    with self.test_session():
-      x = tf.constant([[-1.0], [-2.0], [-3.0], [-4.0]])
-      self.assertAllClose(log_sum_exp(x).eval(),
-                          -0.5598103014388045)
-      x = tf.constant([[-1.0, -2.0], [-3.0, -4.0]])
-      self.assertAllClose(log_sum_exp(x).eval(),
-                          -0.5598103014388045)
-      self.assertAllClose(log_sum_exp(x, 0).eval(),
-                          np.array([-0.87307198895702742,
-                                    -1.8730719889570275]))
-      self.assertAllClose(log_sum_exp(x, 1).eval(),
-                          np.array([-0.68673831248177708,
-                                    -2.6867383124817774]))
-
-if __name__ == '__main__':
-  tf.test.main()
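
Note: the diff above deletes tests/test-util/test_log_sum_exp.py but, within this diff, adds no test for the new reduce_logmeanexp utility. Below is a minimal sketch of what such a follow-up test could look like, kept in the TF1-era tf.test.TestCase style of the deleted file. Everything in it is illustrative rather than part of this change: the file name, class and test names are hypothetical, and it assumes edward.util re-exports reduce_logmeanexp, as the updated import block in edward/__init__.py implies. Expected values are computed with NumPy instead of hard-coded, using the identity log(mean(exp(x))) = logsumexp(x) - log(n).

# Hypothetical tests/test-util/test_reduce_logmeanexp.py (sketch, not part of this diff).
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

# Assumes edward.util re-exports reduce_logmeanexp from edward/util/tensorflow.py.
from edward.util import reduce_logmeanexp


class test_reduce_logmeanexp_class(tf.test.TestCase):

  def test_reduce_logmeanexp_1d(self):
    with self.test_session():
      x_np = np.array([-1.0, -2.0, -3.0, -4.0], dtype=np.float32)
      x = tf.constant(x_np)
      # log(mean(exp(x))) equals logsumexp(x) - log(n); compare against
      # the NumPy reference instead of a hard-coded constant.
      self.assertAllClose(reduce_logmeanexp(x).eval(),
                          np.log(np.mean(np.exp(x_np))))

  def test_reduce_logmeanexp_2d(self):
    with self.test_session():
      x_np = np.array([[-1.0, -2.0], [-3.0, -4.0]], dtype=np.float32)
      x = tf.constant(x_np)
      # Reduce over all dimensions.
      self.assertAllClose(reduce_logmeanexp(x).eval(),
                          np.log(np.mean(np.exp(x_np))))
      # Reduce over a single axis.
      self.assertAllClose(reduce_logmeanexp(x, 0).eval(),
                          np.log(np.mean(np.exp(x_np), 0)))

if __name__ == '__main__':
  tf.test.main()

If the project keeps one test file per utility function, a natural (hypothetical) location for this would be tests/test-util/test_reduce_logmeanexp.py, mirroring the file removed above.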