From fae636e0486a43974629f718f9b210cfebdf2a50 Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Sun, 5 Mar 2017 17:07:59 -0500 Subject: [PATCH] Laplace approximation now uses Multivariate Normal's (#506) * tease out Laplace into new file laplace.py * replace hessian utility fn with tf.hessians * update laplace to work with MultivariateNormal* * add unit test; update docs * clean up code on pointmass vs normal * revise docs --- docs/tex/bib.bib | 9 ++ docs/tex/tutorials/map-laplace.tex | 12 +- edward/__init__.py | 4 +- edward/inferences/__init__.py | 1 + edward/inferences/laplace.py | 136 +++++++++++++++++++++ edward/inferences/map.py | 27 +--- edward/inferences/variational_inference.py | 41 +++---- edward/util/tensorflow.py | 58 --------- tests/test-inferences/test_laplace.py | 120 ++++++++++++++++++ tests/test-util/test_hessian.py | 68 ----------- 10 files changed, 295 insertions(+), 181 deletions(-) create mode 100644 edward/inferences/laplace.py create mode 100644 tests/test-inferences/test_laplace.py delete mode 100644 tests/test-util/test_hessian.py diff --git a/docs/tex/bib.bib b/docs/tex/bib.bib index 28c44ca5e..16bd7f55c 100644 --- a/docs/tex/bib.bib +++ b/docs/tex/bib.bib @@ -652,3 +652,12 @@ @article{marin2012approximate number = {6}, pages = {1167--1180} } + +@article{fisher1925theory, +author = {Fisher, R A}, +title = {{Theory of statistical estimation}}, +journal = {Mathematical Proceedings of the Cambridge Philosophical Society}, +year = {1925}, +volume = {22}, +number = {5} +} diff --git a/docs/tex/tutorials/map-laplace.tex b/docs/tex/tutorials/map-laplace.tex index 54a19cc0b..3dc04fe22 100644 --- a/docs/tex/tutorials/map-laplace.tex +++ b/docs/tex/tutorials/map-laplace.tex @@ -20,13 +20,16 @@ \subsection{Laplace approximation} &\approx \text{Normal}(\mathbf{z}\;;\; \mathbf{z}_\text{MAP}, \Lambda^{-1}). \end{align*} -This requires computing a precision matrix $\Lambda$. The Laplace approximation -uses the Hessian of the log joint density at the MAP estimate, -defined component-wise as +This requires computing a precision matrix $\Lambda$. Derived from a +Taylor expansion, the Laplace approximation uses the Hessian of the +negative log joint density at the MAP estimate. For flat priors +(equivalent to maximum likelihood), the precision matrix is known +as the observed Fisher information \citep{fisher1925theory}. +It is defined component-wise as \begin{align*} \Lambda_{ij} &= - \frac{\partial^2 \log p(\mathbf{x}, \mathbf{z})}{\partial z_i \partial z_j}. + \frac{\partial^2}{\partial z_i \partial z_j} -\log p(\mathbf{x}, \mathbf{z}). \end{align*} Edward uses automatic differentiation, specifically with TensorFlow's computational graphs, making this gradient computation both simple and @@ -37,4 +40,3 @@ \subsection{Laplace approximation} implementation in Edward's code base. 
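A one-dimensional example makes the precision formula concrete. For observations $x_1, \ldots, x_N$ modeled as $\text{Normal}(x_n \;;\; z, \sigma^2)$ with known $\sigma^2$ and a flat prior on $z$, the negative log joint is $\sum_{n} (x_n - z)^2 / (2\sigma^2)$ up to an additive constant, so
\begin{align*}
\Lambda
&=
\frac{\partial^2}{\partial z^2} \sum_{n=1}^N \frac{(x_n - z)^2}{2\sigma^2}
= \frac{N}{\sigma^2},
\end{align*}
and the Laplace approximation $\text{Normal}(z \;;\; \bar{x}, \sigma^2/N)$ coincides with the exact posterior in this flat-prior case.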
\subsubsection{References}\label{references} - diff --git a/edward/__init__.py b/edward/__init__.py index 70b3d049a..fee19e808 100644 --- a/edward/__init__.py +++ b/edward/__init__.py @@ -21,6 +21,6 @@ RandomVariable from edward.util import copy, dot, get_ancestors, get_children, \ get_descendants, get_dims, get_parents, get_session, get_siblings, \ - get_variables, hessian, logit, multivariate_rbf, placeholder, \ - random_variables, rbf, reduce_logmeanexp, set_seed, to_simplex + get_variables, logit, multivariate_rbf, placeholder, random_variables, \ + rbf, reduce_logmeanexp, set_seed, to_simplex from edward.version import __version__ diff --git a/edward/inferences/__init__.py b/edward/inferences/__init__.py index 7577d15f2..f02776186 100644 --- a/edward/inferences/__init__.py +++ b/edward/inferences/__init__.py @@ -8,6 +8,7 @@ from edward.inferences.inference import * from edward.inferences.klpq import * from edward.inferences.klqp import * +from edward.inferences.laplace import * from edward.inferences.map import * from edward.inferences.metropolis_hastings import * from edward.inferences.monte_carlo import * diff --git a/edward/inferences/laplace.py b/edward/inferences/laplace.py new file mode 100644 index 000000000..5effc45fc --- /dev/null +++ b/edward/inferences/laplace.py @@ -0,0 +1,136 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six +import tensorflow as tf + +from edward.inferences.map import MAP +from edward.models import \ + MultivariateNormalCholesky, MultivariateNormalDiag, \ + MultivariateNormalFull, PointMass, RandomVariable +from edward.util import get_session, get_variables + + +class Laplace(MAP): + """Laplace approximation (Laplace, 1774). + + It approximates the posterior distribution using a multivariate + normal distribution centered at the mode of the posterior. + + We implement this by running ``MAP`` to find the posterior mode. + This forms the mean of the normal approximation. We then compute the + inverse Hessian at the mode of the posterior. This forms the + covariance of the normal approximation. + """ + def __init__(self, latent_vars, data=None, model_wrapper=None): + """ + Parameters + ---------- + latent_vars : list of RandomVariable or + dict of RandomVariable to RandomVariable + Collection of random variables to perform inference on. If list, + each random variable will be implictly optimized using a + ``MultivariateNormalCholesky`` random variable that is defined + internally (with unconstrained support). If dictionary, each + random variable must be a ``MultivariateNormalCholesky``, + ``MultivariateNormalFull``, or ``MultivariateNormalDiag`` random + variable. + + Notes + ----- + If ``MultivariateNormalDiag`` random variables are specified as + approximations, then the Laplace approximation will only produce + the diagonal. This does not capture correlation among the + variables but it does not require a potentially expensive matrix + inversion. 
+ + Examples + -------- + >>> X = tf.placeholder(tf.float32, [N, D]) + >>> w = Normal(mu=tf.zeros(D), sigma=tf.ones(D)) + >>> y = Normal(mu=ed.dot(X, w), sigma=tf.ones(N)) + >>> + >>> qw = MultivariateNormalFull(mu=tf.Variable(tf.random_normal([D])), + >>> sigma=tf.Variable(tf.random_normal([D, D]))) + >>> + >>> inference = ed.Laplace({w: qw}, data={X: X_train, y: y_train}) + """ + if isinstance(latent_vars, list): + with tf.variable_scope("posterior"): + if model_wrapper is None: + latent_vars = {rv: MultivariateNormalCholesky( + mu=tf.Variable(tf.random_normal(rv.batch_shape())), + chol=tf.Variable(tf.random_normal( + rv.get_batch_shape().concatenate(rv.get_batch_shape()[-1])))) + for rv in latent_vars} + elif len(latent_vars) == 1: + latent_vars = {latent_vars[0]: MultivariateNormalCholesky( + mu=tf.Variable(tf.random_normal([model_wrapper.n_vars])), + chol=tf.Variable(tf.random_normal([model_wrapper.n_vars] * 2)))} + elif len(latent_vars) == 0: + latent_vars = {} + else: + raise NotImplementedError("A list of more than one element is " + "not supported. See documentation.") + elif isinstance(latent_vars, dict): + for qz in six.itervalues(latent_vars): + if not isinstance( + qz, (MultivariateNormalCholesky, MultivariateNormalDiag, + MultivariateNormalFull)): + raise TypeError("Posterior approximation must consist of only " + "MultivariateCholesky, MultivariateNormalDiag, " + "or MultivariateNormalFull random variables.") + + # call grandparent's method; avoid parent (MAP) + super(MAP, self).__init__(latent_vars, data, model_wrapper) + + def initialize(self, var_list=None, *args, **kwargs): + # Store latent variables in a temporary attribute; MAP will + # optimize ``PointMass`` random variables, which subsequently + # optimizes mean parameters of the normal approximations. + self.latent_vars_normal = self.latent_vars.copy() + self.latent_vars = {z: PointMass(params=qz.mu) + for z, qz in six.iteritems(self.latent_vars_normal)} + super(Laplace, self).initialize(var_list, *args, **kwargs) + + def finalize(self, feed_dict=None): + """Function to call after convergence. + + Computes the Hessian at the mode. + + Parameters + ---------- + feed_dict : dict, optional + Feed dictionary for a TensorFlow session run during evaluation + of Hessian. It is used to feed placeholders that are not fed + during initialization. 
+ """ + if feed_dict is None: + feed_dict = {} + + for key, value in six.iteritems(self.data): + if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type: + feed_dict[key] = value + + var_list = list(six.itervalues(self.latent_vars)) + hessians = tf.hessians(self.loss, var_list) + + assign_ops = [] + for z, hessian in zip(six.iterkeys(self.latent_vars), hessians): + qz = self.latent_vars_normal[z] + sigma_var = get_variables(qz.sigma)[0] + if isinstance(qz, MultivariateNormalCholesky): + sigma = tf.matrix_inverse(tf.cholesky(hessian)) + elif isinstance(qz, MultivariateNormalDiag): + sigma = 1.0 / tf.diag_part(hessian) + else: # qz is MultivariateNormalFull + sigma = tf.matrix_inverse(hessian) + + assign_ops.append(sigma_var.assign(sigma)) + + sess = get_session() + sess.run(assign_ops, feed_dict) + self.latent_vars = self.latent_vars_normal.copy() + del self.latent_vars_normal + super(Laplace, self).finalize() diff --git a/edward/inferences/map.py b/edward/inferences/map.py index 52850c540..8d804a914 100644 --- a/edward/inferences/map.py +++ b/edward/inferences/map.py @@ -7,7 +7,7 @@ from edward.inferences.variational_inference import VariationalInference from edward.models import RandomVariable, PointMass -from edward.util import copy, hessian +from edward.util import copy class MAP(VariationalInference): @@ -143,28 +143,3 @@ def build_loss_and_gradients(self, var_list): grads = tf.gradients(loss, [v._ref() for v in var_list]) grads_and_vars = list(zip(grads, var_list)) return loss, grads_and_vars - - -class Laplace(MAP): - """Laplace approximation. - - It approximates the posterior distribution using a normal - distribution centered at the mode of the posterior. - """ - def __init__(self, *args, **kwargs): - super(Laplace, self).__init__(*args, **kwargs) - - def finalize(self): - """Function to call after convergence. - - Computes the Hessian at the mode. - """ - # use only a batch of data to estimate hessian - x = self.data - z = {z: qz.value() for z, qz in six.iteritems(self.latent_vars)} - var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, - scope='posterior') - inv_cov = hessian(self.model_wrapper.log_prob(x, z), var_list) - print("Precision matrix:") - print(inv_cov.eval()) - super(Laplace, self).finalize() diff --git a/edward/inferences/variational_inference.py b/edward/inferences/variational_inference.py index 3490e52ec..961d32bd6 100644 --- a/edward/inferences/variational_inference.py +++ b/edward/inferences/variational_inference.py @@ -44,28 +44,25 @@ def initialize(self, optimizer=None, var_list=None, use_prettytensor=False, """ super(VariationalInference, self).initialize(*args, **kwargs) - if var_list is None: - if self.model_wrapper is None: - # Traverse random variable graphs to get default list of variables. - var_list = set([]) - trainables = tf.trainable_variables() - for z, qz in six.iteritems(self.latent_vars): - if isinstance(z, RandomVariable): - var_list.update(get_variables(z, collection=trainables)) - - var_list.update(get_variables(qz, collection=trainables)) - - for x, qx in six.iteritems(self.data): - if isinstance(x, RandomVariable) and \ - not isinstance(qx, RandomVariable): - var_list.update(get_variables(x, collection=trainables)) - - var_list = list(var_list) - else: - # Variables may not be instantiated for model wrappers until - # their methods are first called. For now, hard-code - # ``var_list`` inside build_losses. - var_list = None + # Variables may not be instantiated for model wrappers until + # their methods are first called. 
For now, hard-code + # ``var_list`` inside ``build_loss_and_gradients``. + if var_list is None and self.model_wrapper is None: + # Traverse random variable graphs to get default list of variables. + var_list = set() + trainables = tf.trainable_variables() + for z, qz in six.iteritems(self.latent_vars): + if isinstance(z, RandomVariable): + var_list.update(get_variables(z, collection=trainables)) + + var_list.update(get_variables(qz, collection=trainables)) + + for x, qx in six.iteritems(self.data): + if isinstance(x, RandomVariable) and \ + not isinstance(qx, RandomVariable): + var_list.update(get_variables(x, collection=trainables)) + + var_list = list(var_list) self.loss, grads_and_vars = self.build_loss_and_gradients(var_list) diff --git a/edward/util/tensorflow.py b/edward/util/tensorflow.py index 42dd699a1..348a86889 100644 --- a/edward/util/tensorflow.py +++ b/edward/util/tensorflow.py @@ -52,64 +52,6 @@ def dot(x, y): return tf.reshape(tf.matmul(mat, tf.expand_dims(vec, 1)), [-1]) -def hessian(y, xs): - """Calculate Hessian of y with respect to each x in xs. - - Parameters - ---------- - y : tf.Tensor - Tensor to calculate Hessian of. - xs : list of tf.Variable - List of TensorFlow variables to calculate with respect to. - The variables can have different shapes. - - Returns - ------- - tf.Tensor - A 2-D tensor where each row is - .. math:: \partial_{xs} ( [ \partial_{xs} y ]_j ). - - Raises - ------ - InvalidArgumentError - If the inputs have Inf or NaN values. - """ - y = tf.convert_to_tensor(y) - dependencies = [tf.verify_tensor_all_finite(y, msg='')] - dependencies.extend([tf.verify_tensor_all_finite(x, msg='') for x in xs]) - - with tf.control_dependencies(dependencies): - # Calculate flattened vector grad_{xs} y. - grads = tf.gradients(y, xs) - grads = [tf.reshape(grad, [-1]) for grad in grads] - grads = tf.concat(grads, 0) - # Loop over each element in the vector. - mat = [] - d = grads.get_shape()[0] - if not isinstance(d, int): - d = grads.eval().shape[0] - - for j in range(d): - # Calculate grad_{xs} ( [ grad_{xs} y ]_j ). - gradjgrads = tf.gradients(grads[j], xs) - # Flatten into vector. - hi = [] - for l in range(len(xs)): - hij = gradjgrads[l] - # return 0 if gradient doesn't exist; TensorFlow returns None - if hij is None: - hij = tf.zeros(xs[l].get_shape(), dtype=tf.float32) - - hij = tf.reshape(hij, [-1]) - hi.append(hij) - - hi = tf.concat(hi, 0) - mat.append(hi) - - # Form matrix where each row is grad_{xs} ( [ grad_{xs} y ]_j ). - return tf.stack(mat) - - def logit(x): """Evaluate :math:`\log(x / (1 - x))` elementwise. 
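The custom ``hessian`` utility deleted above is superseded by TensorFlow's built-in ``tf.hessians``, which the new ``Laplace.finalize`` calls directly. Below is a minimal sketch of the replacement (a toy example, not code from the patch, assuming TensorFlow 1.x graph mode as used throughout this change). Unlike the old utility, ``tf.hessians`` returns one Hessian per input tensor rather than a single block matrix over all inputs.

import tensorflow as tf

# Toy quadratic objective y = 0.5 * x^T A x, whose Hessian is A.
A = tf.constant([[2.0, 3.0],
                 [3.0, 2.0]])
x = tf.Variable([1.0, 2.0])
y = 0.5 * tf.reduce_sum(x * tf.reshape(tf.matmul(A, tf.reshape(x, [2, 1])), [2]))

# tf.hessians returns a list with one Hessian per tensor in its second
# argument; here a single [2, 2] matrix, expected to equal A.
hess = tf.hessians(y, [x])[0]

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(hess))  # approximately [[2, 3], [3, 2]]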
diff --git a/tests/test-inferences/test_laplace.py b/tests/test-inferences/test_laplace.py new file mode 100644 index 000000000..383b4b0ad --- /dev/null +++ b/tests/test-inferences/test_laplace.py @@ -0,0 +1,120 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import edward as ed +import numpy as np +import tensorflow as tf + +from edward.models import \ + MultivariateNormalCholesky, MultivariateNormalDiag, \ + MultivariateNormalFull, Normal + + +def build_toy_dataset(N, w, noise_std=0.1): + D = len(w) + x = np.random.randn(N, D).astype(np.float32) + y = np.dot(x, w) + np.random.normal(0, noise_std, size=N) + return x, y + + +class test_laplace_class(tf.test.TestCase): + + def _setup(self): + N = 250 # number of data points + D = 5 # number of features + + # DATA + w_true = np.ones(D) * 5.0 + X_train, y_train = build_toy_dataset(N, w_true) + + # MODEL + X = tf.placeholder(tf.float32, [N, D]) + w = Normal(mu=tf.zeros(D), sigma=tf.ones(D)) + b = Normal(mu=tf.zeros(1), sigma=tf.ones(1)) + y = Normal(mu=ed.dot(X, w) + b, sigma=tf.ones(N)) + + return N, D, w_true, X_train, y_train, X, w, b, y + + def _test(self, sess, qw, qb, w_true): + qw_mu, qb_mu, qw_sigma_det, qb_sigma_det = \ + sess.run([qw.mu, qb.mu, qw.sigma_det(), qb.sigma_det()]) + + self.assertAllClose(qw_mu, w_true, atol=0.5) + self.assertAllClose(qb_mu, np.array([0.0]), atol=0.5) + self.assertAllClose(qw_sigma_det, 0.0, atol=0.1) + self.assertAllClose(qb_sigma_det, 0.0, atol=0.1) + + def test_list(self): + with self.test_session() as sess: + N, D, w_true, X_train, y_train, X, w, b, y = self._setup() + + # INFERENCE + inference = ed.Laplace([w, b], data={X: X_train, y: y_train}) + inference.run(n_iter=100) + + qw = inference.latent_vars[w] + qb = inference.latent_vars[b] + self._test(sess, qw, qb, w_true) + + def test_multivariate_normal_cholesky(self): + with self.test_session() as sess: + N, D, w_true, X_train, y_train, X, w, b, y = self._setup() + + # INFERENCE. Initialize sigma's at identity to verify if we + # learned an approximately zero determinant. + qw = MultivariateNormalCholesky( + mu=tf.Variable(tf.random_normal([D])), + chol=tf.Variable(tf.diag(tf.ones(D)))) + qb = MultivariateNormalCholesky( + mu=tf.Variable(tf.random_normal([1])), + chol=tf.Variable(tf.diag(tf.ones(1)))) + + inference = ed.Laplace({w: qw, b: qb}, data={X: X_train, y: y_train}) + inference.run(n_iter=100) + + self._test(sess, qw, qb, w_true) + + def test_multivariate_normal_diag(self): + with self.test_session() as sess: + N, D, w_true, X_train, y_train, X, w, b, y = self._setup() + + # INFERENCE. Initialize sigma's at identity to verify if we + # learned an approximately zero determinant. + qw = MultivariateNormalDiag( + mu=tf.Variable(tf.random_normal([D])), + diag_stdev=tf.Variable(tf.ones(D))) + qb = MultivariateNormalDiag( + mu=tf.Variable(tf.random_normal([1])), + diag_stdev=tf.Variable(tf.ones(1))) + + inference = ed.Laplace({w: qw, b: qb}, data={X: X_train, y: y_train}) + inference.run(n_iter=100) + + self._test(sess, qw, qb, w_true) + self.assertAllClose(qw.sigma.eval(), + tf.diag(tf.diag_part(qw.sigma)).eval()) + self.assertAllClose(qb.sigma.eval(), + tf.diag(tf.diag_part(qb.sigma)).eval()) + + def test_multivariate_normal_full(self): + with self.test_session() as sess: + N, D, w_true, X_train, y_train, X, w, b, y = self._setup() + + # INFERENCE. Initialize sigma's at identity to verify if we + # learned an approximately zero determinant. 
+ qw = MultivariateNormalFull( + mu=tf.Variable(tf.random_normal([D])), + sigma=tf.Variable(tf.diag(tf.ones(D)))) + qb = MultivariateNormalFull( + mu=tf.Variable(tf.random_normal([1])), + sigma=tf.Variable(tf.diag(tf.ones(1)))) + + inference = ed.Laplace({w: qw, b: qb}, data={X: X_train, y: y_train}) + inference.run(n_iter=100) + + self._test(sess, qw, qb, w_true) + +if __name__ == '__main__': + ed.set_seed(42) + tf.test.main() diff --git a/tests/test-util/test_hessian.py b/tests/test-util/test_hessian.py deleted file mode 100644 index fdc520c4c..000000000 --- a/tests/test-util/test_hessian.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf - -from edward.util import hessian - - -class test_hessian_class(tf.test.TestCase): - - def test_hessian_0d(self): - with self.test_session(): - x1 = tf.Variable(tf.random_normal([1], dtype=tf.float32)) - x2 = tf.Variable(tf.random_normal([1], dtype=tf.float32)) - y = tf.pow(x1, tf.constant(2.0)) + tf.constant(2.0) * x1 * x2 + \ - tf.constant(3.0) * tf.pow(x2, tf.constant(2.0)) + \ - tf.constant(4.0) * x1 + tf.constant(5.0) * x2 + tf.constant(6.0) - tf.global_variables_initializer().run() - self.assertAllEqual(hessian(y, [x1]).eval(), - np.array([[2.0]])) - self.assertAllEqual(hessian(y, [x2]).eval(), - np.array([[6.0]])) - - def test_hessian_1d(self): - with self.test_session(): - x1 = tf.Variable(tf.random_normal([1], dtype=tf.float32)) - x2 = tf.Variable(tf.random_normal([1], dtype=tf.float32)) - y = tf.pow(x1, tf.constant(2.0)) + tf.constant(2.0) * x1 * x2 + \ - tf.constant(3.0) * tf.pow(x2, tf.constant(2.0)) + \ - tf.constant(4.0) * x1 + tf.constant(5.0) * x2 + tf.constant(6.0) - x3 = tf.Variable(tf.random_normal([3], dtype=tf.float32)) - z = tf.pow(x2, tf.constant(2.0)) + tf.reduce_sum(x3) - tf.global_variables_initializer().run() - self.assertAllEqual(hessian(y, [x1, x2]).eval(), - np.array([[2.0, 2.0], [2.0, 6.0]])) - self.assertAllEqual(hessian(z, [x3]).eval(), - np.zeros([3, 3])) - self.assertAllEqual(hessian(z, [x2, x3]).eval(), - np.diag([2.0, 0.0, 0.0, 0.0])) - - def test_hessian_2d(self): - with self.test_session(): - x1 = tf.Variable(tf.random_normal([3, 2], dtype=tf.float32)) - x2 = tf.Variable(tf.random_normal([2], dtype=tf.float32)) - y = tf.reduce_sum(tf.pow(x1, tf.constant(2.0))) + tf.reduce_sum(x2) - tf.global_variables_initializer().run() - self.assertAllEqual(hessian(y, [x1]).eval(), - np.diag([2.0] * 6)) - self.assertAllEqual(hessian(y, [x1, x2]).eval(), - np.diag([2.0] * 6 + [0.0] * 2)) - - def test_all_finite_raises(self): - with self.test_session(): - x1 = tf.Variable(np.nan * tf.random_normal([1], dtype=tf.float32)) - x2 = tf.Variable(tf.random_normal([1], dtype=tf.float32)) - y = tf.pow(x1, tf.constant(2.0)) + tf.constant(2.0) * x1 * x2 + \ - tf.constant(3.0) * tf.pow(x2, tf.constant(2.0)) + \ - tf.constant(4.0) * x1 + tf.constant(5.0) * x2 + tf.constant(6.0) - tf.global_variables_initializer().run() - with self.assertRaisesOpError('NaN'): - hessian(y, [x1]).eval() - with self.assertRaisesOpError('NaN'): - hessian(y, [x1, x2]).eval() - -if __name__ == '__main__': - tf.test.main()
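Returning to ``Laplace.finalize`` above, the covariance-recovery step is easy to illustrate with NumPy (made-up numbers, an illustration rather than code from the patch): the ``MultivariateNormalFull`` branch inverts the full Hessian, while the ``MultivariateNormalDiag`` branch inverts only its diagonal, which avoids a matrix inverse but, as the class docstring notes, drops the correlations.

import numpy as np

# Hessian of the negative log joint at the MAP estimate (toy values).
H = np.array([[4.0, 1.0],
              [1.0, 3.0]])

# MultivariateNormalFull branch: full covariance via a matrix inverse.
sigma_full = np.linalg.inv(H)   # [[0.273, -0.091], [-0.091, 0.364]]

# MultivariateNormalDiag branch: invert only diag(H). Note this is not
# the same as diag(inv(H)) whenever H has off-diagonal structure.
sigma_diag = 1.0 / np.diag(H)   # [0.25, 0.333]

print(np.diag(sigma_full))      # [0.273, 0.364]
print(sigma_diag)               # [0.25, 0.333]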