diff --git a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py b/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py index ff9abfea456f49..2eb625487f4cd1 100644 --- a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py +++ b/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py @@ -113,7 +113,7 @@ def testSplitApplyMerge(self): with self.test_session() as sess: # Use sampling to train REINFORCE - with st.value_type(st.SampleAndReshapeValue(n=1)): + with st.value_type(st.SampleValue()): (route_selection, routing_loss, final_loss) = build_split_apply_merge_model() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py index de5c5c82b82034..5d4fc66c69ae5c 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py @@ -38,7 +38,7 @@ def testPathwiseDerivativeDoesNotAddSurrogateLosses(self): with self.test_session(): mu = [0.0, 0.1, 0.2] sigma = tf.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleAndReshapeValue()): + with st.value_type(st.SampleValue()): prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) likelihood = st.StochasticTensor( distributions.Normal(mu=prior, sigma=sigma)) @@ -76,7 +76,7 @@ def testSurrogateLoss(self): with self.test_session() as sess: mu = tf.constant([0.0, 0.1, 0.2]) sigma = tf.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleAndReshapeValue()): + with st.value_type(st.SampleValue()): prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma)) prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) @@ -153,7 +153,7 @@ def testNoSurrogateLoss(self): with self.test_session(): mu = tf.constant([0.0, 0.1, 0.2]) sigma = tf.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleAndReshapeValue()): + with st.value_type(st.SampleValue()): dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma), loss_fn=None) self.assertEqual(None, dt.loss(tf.constant([2.0]))) @@ -162,7 +162,7 @@ def testExplicitStochasticTensors(self): with self.test_session() as sess: mu = tf.constant([0.0, 0.1, 0.2]) sigma = tf.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleAndReshapeValue()): + with st.value_type(st.SampleValue()): dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma)) loss = tf.square(tf.identity(dt1)) + 10. + dt2 diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py index b7bd2adfe8abdf..b73e87ce2830b1 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py @@ -37,19 +37,19 @@ def testConstructionAndValue(self): prior_default = st.StochasticTensor( distributions.Normal(mu=mu, sigma=sigma)) self.assertTrue( - isinstance(prior_default.value_type, st.SampleAndReshapeValue)) + isinstance(prior_default.value_type, st.SampleValue)) prior_0 = st.StochasticTensor( distributions.Normal(mu=mu, sigma=sigma), - dist_value_type=st.SampleAndReshapeValue()) - self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue)) + dist_value_type=st.SampleValue()) + self.assertTrue(isinstance(prior_0.value_type, st.SampleValue)) - with st.value_type(st.SampleAndReshapeValue()): + with st.value_type(st.SampleValue()): prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma)) - self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue)) + self.assertTrue(isinstance(prior.value_type, st.SampleValue)) likelihood = st.StochasticTensor( distributions.Normal(mu=prior, sigma=sigma2)) self.assertTrue( - isinstance(likelihood.value_type, st.SampleAndReshapeValue)) + isinstance(likelihood.value_type, st.SampleValue)) coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) self.assertEqual(coll, [prior_default, prior_0, prior, likelihood]) @@ -87,15 +87,14 @@ def testMeanValue(self): self.assertAllEqual(prior_mean_val, mu) self.assertAllEqual(prior_mean_val, prior_value_val) - def testSampleAndReshapeValue(self): + def testSampleValueScalar(self): with self.test_session() as sess: mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]] sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]]) - with st.value_type(st.SampleAndReshapeValue()): + with st.value_type(st.SampleValue()): prior_single = st.StochasticTensor( - distributions.Normal( - mu=mu, sigma=sigma)) + distributions.Normal(mu=mu, sigma=sigma)) prior_single_value = prior_single.value() self.assertEqual(prior_single_value.get_shape(), (2, 3)) @@ -103,22 +102,7 @@ def testSampleAndReshapeValue(self): prior_single_value_val = sess.run([prior_single_value])[0] self.assertEqual(prior_single_value_val.shape, (2, 3)) - with st.value_type(st.SampleAndReshapeValue(n=2)): - prior_double = st.StochasticTensor( - distributions.Normal(mu=mu, sigma=sigma)) - - prior_double_value = prior_double.value() - self.assertEqual(prior_double_value.get_shape(), (4, 3)) - - prior_double_value_val = sess.run([prior_double_value])[0] - self.assertEqual(prior_double_value_val.shape, (4, 3)) - - def testSampleValue(self): - with self.test_session() as sess: - mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]] - sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]]) - - with st.value_type(st.SampleValue()): + with st.value_type(st.SampleValue(1)): prior_single = st.StochasticTensor( distributions.Normal(mu=mu, sigma=sigma)) self.assertTrue(isinstance(prior_single.value_type, st.SampleValue)) @@ -129,7 +113,7 @@ def testSampleValue(self): prior_single_value_val = sess.run([prior_single_value])[0] self.assertEqual(prior_single_value_val.shape, (1, 2, 3)) - with st.value_type(st.SampleValue(n=2)): + with st.value_type(st.SampleValue(2)): prior_double = st.StochasticTensor( distributions.Normal(mu=mu, sigma=sigma)) @@ -182,7 +166,7 @@ class ValueTypeTest(tf.test.TestCase): def testValueType(self): type_mean = st.MeanValue() - type_reshape = st.SampleAndReshapeValue() + type_reshape = st.SampleValue() type_full = st.SampleValue() with st.value_type(type_mean): self.assertEqual(st.get_current_value_type(), type_mean) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py index eaee3344e5d21e..e52c81740dfe4d 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py +++ b/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py @@ -31,7 +31,6 @@ @@MeanValue @@SampleValue -@@SampleAndReshapeValue @@value_type @@get_current_value_type @@ -51,7 +50,6 @@ from tensorflow.contrib import distributions from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops STOCHASTIC_TENSOR_COLLECTION = "_stochastic_tensor_collection_" @@ -122,8 +120,7 @@ def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False): class _StochasticValueType(object): """Interface for the ValueType classes. - This is the base class for MeanValue, SampleValue, SampleAndReshapeValue, - and their descendants. + This is the base class for MeanValue, SampleValue, and their descendants. """ def pushed_above(self, unused_value_type): @@ -155,89 +152,53 @@ def stop_gradient(self): class SampleValue(_StochasticValueType): - """Draw n samples along a new outer dimension. + """Draw samples, possibly adding new outer dimensions along the way. - This ValueType draws `n` samples from StochasticTensors run within its - context, increasing the rank by one along a new outer dimension. + This ValueType draws samples from StochasticTensors run within its + context, increasing the rank according to the requested shape. - Example: + Examples: ```python mu = tf.zeros((2,3)) sigma = tf.ones((2, 3)) - with sg.value_type(sg.SampleValue(n=4)): + with sg.value_type(sg.SampleValue()): st = sg.StochasticTensor( distributions.Normal, mu=mu, sigma=sigma) - # draws 4 samples each with shape (2, 3) and concatenates - assertEqual(st.value().get_shape(), (4, 2, 3)) + # draws 1 sample and does not reshape + assertEqual(st.value().get_shape(), (2, 3)) ``` - """ - - def __init__(self, n=1, stop_gradient=False): - """Sample `n` times and concatenate along a new outer dimension. - - Args: - n: A python integer or int32 tensor. The number of samples to take. - stop_gradient: If `True`, StochasticTensors' values are wrapped in - `stop_gradient`, to avoid backpropagation through. - """ - self._n = n - self._stop_gradient = stop_gradient - - @property - def n(self): - return self._n - - @property - def stop_gradient(self): - return self._stop_gradient - - -class SampleAndReshapeValue(_StochasticValueType): - """Ask the StochasticTensor for n samples and reshape the result. - - Sampling from a StochasticTensor increases the rank of the value by 1 - (because each sample represents a new outer dimension). - - This ValueType requests `n` samples from StochasticTensors run within its - context that the outer two dimensions are reshaped to intermix the samples - with the outermost (usually batch) dimension. - - Example: ```python - # mu and sigma are both shaped (2, 3) - mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]] - sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]]) - - with sg.value_type(sg.SampleAndReshapeValue(n=2)): + mu = tf.zeros((2,3)) + sigma = tf.ones((2, 3)) + with sg.value_type(sg.SampleValue(4)): st = sg.StochasticTensor( - distributions.Normal, mu=mu, sigma=sigma) - - # sample(2) creates a (2, 2, 3) tensor, and the two outermost dimensions - # are reshaped into one: the final value is a (4, 3) tensor. - st_value = st.value() - assertEqual(st_value.get_shape(), (4, 3)) - - st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0] - assertEqual(st_value_val.shape, (4, 3)) + distributions.Normal, mu=mu, sigma=sigma) + # draws 4 samples each with shape (2, 3) and concatenates + assertEqual(st.value().get_shape(), (4, 2, 3)) ``` """ - def __init__(self, n=1, stop_gradient=False): - """Sample `n` times and reshape the outer 2 axes so rank does not change. + def __init__(self, shape=(), stop_gradient=False): + """Sample according to shape. + + For the given StochasticTensor `st` using this value type, + the shape of `st.value()` will match that of + `st.distribution.sample(shape)`. Args: - n: A python integer or int32 tensor. The number of samples to take. + shape: A shape tuple or int32 tensor. The sample shape. + Default is a scalar: take one sample and do not change the size. stop_gradient: If `True`, StochasticTensors' values are wrapped in `stop_gradient`, to avoid backpropagation through. """ - self._n = n + self._shape = shape self._stop_gradient = stop_gradient @property - def n(self): - return self._n + def shape(self): + return self._shape @property def stop_gradient(self): @@ -267,7 +228,7 @@ def value_type(dist_value_type): in a `stop_gradients` call to disable any possible backpropagation. Args: - dist_value_type: An instance of `MeanValue`, `SampleAndReshapeValue`, or + dist_value_type: An instance of `MeanValue`, `SampleValue`, or any other stochastic value type. Yields: @@ -317,7 +278,7 @@ def __init__(self, `StochasticTensor` is backed by the `dist` distribution and its `value` method will return the same value each time it is called. What `value` is returned is controlled by the `dist_value_type` (defaults to - `SampleAndReshapeValue`). + `SampleValue`). Some distributions' sample functions are not differentiable (e.g. a sample from a discrete distribution like a Bernoulli) and so to differentiate @@ -356,7 +317,7 @@ def __init__(self, try: self._value_type = get_current_value_type() except NoValueTypeSetError: - self._value_type = SampleAndReshapeValue() + self._value_type = SampleValue() else: # We want to enforce a value type here, but use the value_type() # context manager to enforce some error checking. @@ -388,26 +349,7 @@ def _create_value(self): if isinstance(self._value_type, MeanValue): value_tensor = self._dist.mean() elif isinstance(self._value_type, SampleValue): - value_tensor = self._dist.sample(self._value_type.n) - elif isinstance(self._value_type, SampleAndReshapeValue): - if self._value_type.n == 1: - value_tensor = self._dist.sample() - else: - samples = self._dist.sample(self._value_type.n) - samples_shape = array_ops.shape(samples) - samples_static_shape = samples.get_shape() - new_batch_size = samples_shape[0] * samples_shape[1] - value_tensor = array_ops.reshape( - samples, array_ops.concat(0, ([new_batch_size], samples_shape[2:]))) - if samples_static_shape.ndims is not None: - # Update the static shape for shape inference purposes - shape_list = samples_static_shape.as_list() - new_shape = tensor_shape.vector( - shape_list[0] * shape_list[1] - if shape_list[0] is not None and shape_list[1] is not None - else None) - new_shape = new_shape.concatenate(samples_static_shape[2:]) - value_tensor.set_shape(new_shape) + value_tensor = self._dist.sample(self._value_type.shape) else: raise TypeError( "Unrecognized Distribution Value Type: %s", self._value_type) @@ -462,7 +404,6 @@ def loss(self, final_loss, name="Loss"): with ops.name_scope(self.name, values=[final_loss]): with ops.name_scope(name): if (self._value_type.stop_gradient or - isinstance(self._value_type, SampleAndReshapeValue) or isinstance(self._value_type, SampleValue)): return self._loss_fn(self, self._value, final_loss) elif isinstance(self._value_type, MeanValue): @@ -530,7 +471,6 @@ def loss(self, final_loss, name=None): "ObservedStochasticTensor", "MeanValue", "SampleValue", - "SampleAndReshapeValue", "value_type", "get_current_value_type", ]