tf.contrib.distributions.Mixture now takes a list of Distribution instances
as components.

After some benchmarking, it turns out that gathering parameters during sampling
and creating new Distributions can be substantially slower and consume more memory.

Since we don't need access to the parameters, the interface can be cleaned up.
Change: 133961696
ebrevdo authored and tensorflower-gardener committed Sep 22, 2016
1 parent 47980d7 commit 271260b
Showing 2 changed files with 72 additions and 147 deletions.
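
In short, components change from (constructor, parameter-dict) tuples, which
Mixture instantiated internally, to pre-built Distribution instances. A minimal
before/after sketch (parameter values are illustrative, and it uses
tf.contrib.distributions directly rather than the test file's distributions_py
alias):

    import tensorflow as tf

    distributions = tf.contrib.distributions

    cat = distributions.Categorical([-0.5, 0.5])  # two mixture classes

    # Before this commit: components were (constructor, param-dict) tuples,
    # and Mixture built the distributions internally.
    #   mixture = distributions.Mixture(
    #       cat,
    #       [(distributions.Normal, {"mu": -1.0, "sigma": 0.5}),
    #        (distributions.Normal, {"mu": 1.0, "sigma": 0.5})])

    # After this commit: components are Distribution instances built up front.
    mixture = distributions.Mixture(
        cat,
        [distributions.Normal(mu=-1.0, sigma=0.5),
         distributions.Normal(mu=1.0, sigma=0.5)])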
@@ -19,7 +19,6 @@
 from __future__ import print_function
 
 import contextlib
-import functools
 
 import numpy as np
 import tensorflow as tf
@@ -69,9 +68,9 @@ def make_univariate_mixture(batch_shape, num_components):
   logits = tf.random_uniform(
       list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
   components = [
-      (distributions_py.Normal,
-       {"mu": np.float32(np.random.randn(*list(batch_shape))),
-        "sigma": np.float32(10 * np.random.rand(*list(batch_shape)))})
+      distributions_py.Normal(
+          mu=np.float32(np.random.randn(*list(batch_shape))),
+          sigma=np.float32(10 * np.random.rand(*list(batch_shape))))
       for _ in range(num_components)
   ]
   cat = distributions_py.Categorical(logits, dtype=tf.int32)
@@ -82,10 +81,10 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape):
   logits = tf.random_uniform(
       list(batch_shape) + [num_components], -1, 1, dtype=tf.float32) - 50.
   components = [
-      (distributions_py.MultivariateNormalDiag,
-       {"mu": np.float32(np.random.randn(*list(batch_shape + event_shape))),
-        "diag_stdev": np.float32(10 * np.random.rand(
-            *list(batch_shape + event_shape)))})
+      distributions_py.MultivariateNormalDiag(
+          mu=np.float32(np.random.randn(*list(batch_shape + event_shape))),
+          diag_stdev=np.float32(10 * np.random.rand(
+              *list(batch_shape + event_shape))))
       for _ in range(num_components)
   ]
   cat = distributions_py.Categorical(logits, dtype=tf.int32)
@@ -116,30 +115,30 @@ def testBrokenShapesStatic(self):
                                                r"cat.num_classes != len"):
         distributions_py.Mixture(
             distributions_py.Categorical([0.1, 0.5]),  # 2 classes
-            [(distributions_py.Normal, {"mu": 1.0, "sigma": 2.0})])
+            [distributions_py.Normal(mu=1.0, sigma=2.0)])
       with self.assertRaisesWithPredicateMatch(
           ValueError, r"\(\) and \(2,\) are not compatible"):
         # The value error is raised because the batch shapes of the
         # Normals are not equal. One is a scalar, the other is a
         # vector of size (2,).
         distributions_py.Mixture(
             distributions_py.Categorical([-0.5, 0.5]),  # scalar batch
-            [(distributions_py.Normal, {"mu": 1.0, "sigma": 2.0}),  # scalar dist
-             (distributions_py.Normal, {"mu": [1.0, 1.0], "sigma": [2.0, 2.0]})])
+            [distributions_py.Normal(mu=1.0, sigma=2.0),  # scalar dist
+             distributions_py.Normal(mu=[1.0, 1.0], sigma=[2.0, 2.0])])
       with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"):
         cat_logits = tf.placeholder(shape=[1, None], dtype=tf.int32)
         distributions_py.Mixture(
             distributions_py.Categorical(cat_logits),
-            [(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]})])
+            [distributions_py.Normal(mu=[1.0], sigma=[2.0])])
 
   def testBrokenShapesDynamic(self):
     with self.test_session():
       d0_param = tf.placeholder(dtype=tf.float32)
       d1_param = tf.placeholder(dtype=tf.float32)
       d = distributions_py.Mixture(
           distributions_py.Categorical([0.1, 0.2]),
-          [(distributions_py.Normal, {"mu": d0_param, "sigma": d0_param}),
-           (distributions_py.Normal, {"mu": d1_param, "sigma": d1_param})],
+          [distributions_py.Normal(mu=d0_param, sigma=d0_param),
+           distributions_py.Normal(mu=d1_param, sigma=d1_param)],
           validate_args=True)
       with self.assertRaisesOpError(r"batch shape must match"):
         d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]})
@@ -150,42 +149,24 @@ def testBrokenTypes(self):
     with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"):
       distributions_py.Mixture(None, [])
     cat = distributions_py.Categorical([0.3, 0.2])
-    # components must be a list of tuples
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
+    # components must be a list of distributions
+    with self.assertRaisesWithPredicateMatch(
+        TypeError, "all .* must be Distribution instances"):
       distributions_py.Mixture(cat, [None])
-    # components tuples must be size 2
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
-      distributions_py.Mixture(cat, [tuple()])
-    # components tuples must be size 2
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
-      distributions_py.Mixture(cat, [(None)])
-    # components tuples must be of the form (callable, dict)
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
-      distributions_py.Mixture(cat, [(None, None)])
-    # components tuples must be size 2
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
-      distributions_py.Mixture(cat, [(None, None, None)])
-    # components tuples must be of the form (callable, dict)
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
-      distributions_py.Mixture(cat, [(lambda x: x, None)])
-    # components tuples must be of the form (callable, dict)
-    with self.assertRaisesWithPredicateMatch(TypeError, "tuples of the form"):
-      distributions_py.Mixture(cat, [(None, {})])
     with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"):
       distributions_py.Mixture(
           cat,
-          [(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]}),
-           (distributions_py.Normal, {"mu": [np.float16(1.0)],
-                                      "sigma": [np.float16(2.0)]})])
+          [distributions_py.Normal(mu=[1.0], sigma=[2.0]),
+           distributions_py.Normal(mu=[np.float16(1.0)],
+                                   sigma=[np.float16(2.0)])])
     with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"):
       distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None)
     with self.assertRaisesWithPredicateMatch(TypeError,
                                              "either be continuous or not"):
       distributions_py.Mixture(
           cat,
-          [(distributions_py.Normal, {"mu": [1.0], "sigma": [2.0]}),
-           (functools.partial(distributions_py.Bernoulli, dtype=tf.float32),
-            {"logits": [1.0]})])
+          [distributions_py.Normal(mu=[1.0], sigma=[2.0]),
+           distributions_py.Bernoulli(dtype=tf.float32, logits=[1.0])])
 
   def testMeanUnivariate(self):
     with self.test_session() as sess:
@@ -196,7 +177,7 @@ def testMeanUnivariate(self):
         self.assertEqual(batch_shape, mean.get_shape())
 
         cat_probs = tf.nn.softmax(dist.cat.logits)
-        dist_means = [d.mean() for d in dist.distributions]
+        dist_means = [d.mean() for d in dist.components]
 
         mean_value, cat_probs_value, dist_means_value = sess.run(
             [mean, cat_probs, dist_means])
@@ -217,7 +198,7 @@ def testMeanMultivariate(self):
         self.assertEqual(batch_shape + (4,), mean.get_shape())
 
         cat_probs = tf.nn.softmax(dist.cat.logits)
-        dist_means = [d.mean() for d in dist.distributions]
+        dist_means = [d.mean() for d in dist.components]
 
         mean_value, cat_probs_value, dist_means_value = sess.run(
             [mean, cat_probs, dist_means])
@@ -243,7 +224,7 @@ def testProbScalarUnivariate(self):
 
       self.assertEqual(x.shape, p_x.get_shape())
       cat_probs = tf.nn.softmax([dist.cat.logits])[0]
-      dist_probs = [d.prob(x) for d in dist.distributions]
+      dist_probs = [d.prob(x) for d in dist.components]
 
       p_x_value, cat_probs_value, dist_probs_value = sess.run(
           [p_x, cat_probs, dist_probs])
@@ -269,7 +250,7 @@ def testProbScalarMultivariate(self):
       self.assertEqual(x.shape[:-1], p_x.get_shape())
 
       cat_probs = tf.nn.softmax([dist.cat.logits])[0]
-      dist_probs = [d.prob(x) for d in dist.distributions]
+      dist_probs = [d.prob(x) for d in dist.components]
 
       p_x_value, cat_probs_value, dist_probs_value = sess.run(
           [p_x, cat_probs, dist_probs])
@@ -292,7 +273,7 @@ def testProbBatchUnivariate(self):
       self.assertEqual(x.shape, p_x.get_shape())
 
       cat_probs = tf.nn.softmax(dist.cat.logits)
-      dist_probs = [d.prob(x) for d in dist.distributions]
+      dist_probs = [d.prob(x) for d in dist.components]
 
       p_x_value, cat_probs_value, dist_probs_value = sess.run(
          [p_x, cat_probs, dist_probs])
@@ -318,7 +299,7 @@ def testProbBatchMultivariate(self):
       self.assertEqual(x.shape[:-1], p_x.get_shape())
 
       cat_probs = tf.nn.softmax(dist.cat.logits)
-      dist_probs = [d.prob(x) for d in dist.distributions]
+      dist_probs = [d.prob(x) for d in dist.components]
 
       p_x_value, cat_probs_value, dist_probs_value = sess.run(
           [p_x, cat_probs, dist_probs])
@@ -430,7 +411,7 @@ def testEntropyLowerBoundMultivariate(self):
       self.assertEqual(batch_shape, entropy_lower_bound.get_shape())
 
       cat_probs = tf.nn.softmax(dist.cat.logits)
-      dist_entropy = [d.entropy() for d in dist.distributions]
+      dist_entropy = [d.entropy() for d in dist.components]
 
       entropy_lower_bound_value, cat_probs_value, dist_entropy_value = (
           sess.run([entropy_lower_bound, cat_probs, dist_entropy]))
@@ -486,8 +467,7 @@ def create_distribution(batch_size, num_components, num_features):
           tf.Variable(np.random.rand(batch_size, num_features))
           for _ in range(num_components)]
       components = list(
-          (distributions_py.MultivariateNormalDiag,
-           {"mu": mu, "diag_stdev": sigma})
+          distributions_py.MultivariateNormalDiag(mu=mu, diag_stdev=sigma)
           for (mu, sigma) in zip(mus, sigmas))
       return distributions_py.Mixture(cat, components)

@@ -524,8 +504,7 @@ def create_distribution(batch_size, num_components, num_features):
           psd(np.random.rand(batch_size, num_features, num_features)))
           for _ in range(num_components)]
       components = list(
-          (distributions_py.MultivariateNormalFull,
-           {"mu": mu, "sigma": sigma})
+          distributions_py.MultivariateNormalFull(mu=mu, sigma=sigma)
          for (mu, sigma) in zip(mus, sigmas))
       return distributions_py.Mixture(cat, components)

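For reference, a sketch of building and querying a mixture with the new
interface, mirroring the benchmark's create_distribution above (shapes and
values are illustrative, not taken from the diff):

    import numpy as np
    import tensorflow as tf

    distributions = tf.contrib.distributions

    batch_size, num_components, num_features = 16, 3, 4
    cat = distributions.Categorical(
        tf.random_uniform([batch_size, num_components], -1, 1,
                          dtype=tf.float32),
        dtype=tf.int32)
    components = [
        distributions.MultivariateNormalDiag(
            mu=np.float32(np.random.randn(batch_size, num_features)),
            diag_stdev=np.float32(10 * np.random.rand(batch_size,
                                                      num_features)))
        for _ in range(num_components)]
    mixture = distributions.Mixture(cat, components)

    x = mixture.sample()   # one draw per batch member: [batch_size, num_features]
    p_x = mixture.prob(x)  # mixture density at x: [batch_size]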