Skip to content

Commit

Permalink
bayesflow: replace SampleAndReshapeValue with SampleValue()
Browse files Browse the repository at this point in the history
Change: 138649779
  • Loading branch information
ebrevdo authored and tensorflower-gardener committed Nov 9, 2016
1 parent d9da972 commit c66f878
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def testSplitApplyMerge(self):

with self.test_session() as sess:
# Use sampling to train REINFORCE
with st.value_type(st.SampleAndReshapeValue(n=1)):
with st.value_type(st.SampleValue()):
(route_selection,
routing_loss,
final_loss) = build_split_apply_merge_model()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def testPathwiseDerivativeDoesNotAddSurrogateLosses(self):
with self.test_session():
mu = [0.0, 0.1, 0.2]
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
with st.value_type(st.SampleValue()):
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
likelihood = st.StochasticTensor(
distributions.Normal(mu=prior, sigma=sigma))
Expand Down Expand Up @@ -76,7 +76,7 @@ def testSurrogateLoss(self):
with self.test_session() as sess:
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
with st.value_type(st.SampleValue()):
prior = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
likelihood = st.StochasticTensor(NormalNotParam(mu=prior, sigma=sigma))
prior_2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
Expand Down Expand Up @@ -153,7 +153,7 @@ def testNoSurrogateLoss(self):
with self.test_session():
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
with st.value_type(st.SampleValue()):
dt = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma),
loss_fn=None)
self.assertEqual(None, dt.loss(tf.constant([2.0])))
Expand All @@ -162,7 +162,7 @@ def testExplicitStochasticTensors(self):
with self.test_session() as sess:
mu = tf.constant([0.0, 0.1, 0.2])
sigma = tf.constant([1.1, 1.2, 1.3])
with st.value_type(st.SampleAndReshapeValue()):
with st.value_type(st.SampleValue()):
dt1 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
dt2 = st.StochasticTensor(NormalNotParam(mu=mu, sigma=sigma))
loss = tf.square(tf.identity(dt1)) + 10. + dt2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ def testConstructionAndValue(self):
prior_default = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(
isinstance(prior_default.value_type, st.SampleAndReshapeValue))
isinstance(prior_default.value_type, st.SampleValue))
prior_0 = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma),
dist_value_type=st.SampleAndReshapeValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleAndReshapeValue))
dist_value_type=st.SampleValue())
self.assertTrue(isinstance(prior_0.value_type, st.SampleValue))

with st.value_type(st.SampleAndReshapeValue()):
with st.value_type(st.SampleValue()):
prior = st.StochasticTensor(distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(isinstance(prior.value_type, st.SampleAndReshapeValue))
self.assertTrue(isinstance(prior.value_type, st.SampleValue))
likelihood = st.StochasticTensor(
distributions.Normal(mu=prior, sigma=sigma2))
self.assertTrue(
isinstance(likelihood.value_type, st.SampleAndReshapeValue))
isinstance(likelihood.value_type, st.SampleValue))

coll = tf.get_collection(st.STOCHASTIC_TENSOR_COLLECTION)
self.assertEqual(coll, [prior_default, prior_0, prior, likelihood])
Expand Down Expand Up @@ -87,38 +87,22 @@ def testMeanValue(self):
self.assertAllEqual(prior_mean_val, mu)
self.assertAllEqual(prior_mean_val, prior_value_val)

def testSampleAndReshapeValue(self):
def testSampleValueScalar(self):
with self.test_session() as sess:
mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])

with st.value_type(st.SampleAndReshapeValue()):
with st.value_type(st.SampleValue()):
prior_single = st.StochasticTensor(
distributions.Normal(
mu=mu, sigma=sigma))
distributions.Normal(mu=mu, sigma=sigma))

prior_single_value = prior_single.value()
self.assertEqual(prior_single_value.get_shape(), (2, 3))

prior_single_value_val = sess.run([prior_single_value])[0]
self.assertEqual(prior_single_value_val.shape, (2, 3))

with st.value_type(st.SampleAndReshapeValue(n=2)):
prior_double = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))

prior_double_value = prior_double.value()
self.assertEqual(prior_double_value.get_shape(), (4, 3))

prior_double_value_val = sess.run([prior_double_value])[0]
self.assertEqual(prior_double_value_val.shape, (4, 3))

def testSampleValue(self):
with self.test_session() as sess:
mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])

with st.value_type(st.SampleValue()):
with st.value_type(st.SampleValue(1)):
prior_single = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))
self.assertTrue(isinstance(prior_single.value_type, st.SampleValue))
Expand All @@ -129,7 +113,7 @@ def testSampleValue(self):
prior_single_value_val = sess.run([prior_single_value])[0]
self.assertEqual(prior_single_value_val.shape, (1, 2, 3))

with st.value_type(st.SampleValue(n=2)):
with st.value_type(st.SampleValue(2)):
prior_double = st.StochasticTensor(
distributions.Normal(mu=mu, sigma=sigma))

Expand Down Expand Up @@ -182,7 +166,7 @@ class ValueTypeTest(tf.test.TestCase):

def testValueType(self):
type_mean = st.MeanValue()
type_reshape = st.SampleAndReshapeValue()
type_reshape = st.SampleValue()
type_full = st.SampleValue()
with st.value_type(type_mean):
self.assertEqual(st.get_current_value_type(), type_mean)
Expand Down
118 changes: 29 additions & 89 deletions tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
@@MeanValue
@@SampleValue
@@SampleAndReshapeValue
@@value_type
@@get_current_value_type
Expand All @@ -51,7 +50,6 @@
from tensorflow.contrib import distributions
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops

STOCHASTIC_TENSOR_COLLECTION = "_stochastic_tensor_collection_"
Expand Down Expand Up @@ -122,8 +120,7 @@ def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False):
class _StochasticValueType(object):
"""Interface for the ValueType classes.
This is the base class for MeanValue, SampleValue, SampleAndReshapeValue,
and their descendants.
This is the base class for MeanValue, SampleValue, and their descendants.
"""

def pushed_above(self, unused_value_type):
Expand Down Expand Up @@ -155,89 +152,53 @@ def stop_gradient(self):


class SampleValue(_StochasticValueType):
"""Draw n samples along a new outer dimension.
"""Draw samples, possibly adding new outer dimensions along the way.
This ValueType draws `n` samples from StochasticTensors run within its
context, increasing the rank by one along a new outer dimension.
This ValueType draws samples from StochasticTensors run within its
context, increasing the rank according to the requested shape.
Example:
Examples:
```python
mu = tf.zeros((2,3))
sigma = tf.ones((2, 3))
with sg.value_type(sg.SampleValue(n=4)):
with sg.value_type(sg.SampleValue()):
st = sg.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
# draws 4 samples each with shape (2, 3) and concatenates
assertEqual(st.value().get_shape(), (4, 2, 3))
# draws 1 sample and does not reshape
assertEqual(st.value().get_shape(), (2, 3))
```
"""

def __init__(self, n=1, stop_gradient=False):
"""Sample `n` times and concatenate along a new outer dimension.
Args:
n: A python integer or int32 tensor. The number of samples to take.
stop_gradient: If `True`, StochasticTensors' values are wrapped in
`stop_gradient`, to avoid backpropagation through.
"""
self._n = n
self._stop_gradient = stop_gradient

@property
def n(self):
return self._n

@property
def stop_gradient(self):
return self._stop_gradient


class SampleAndReshapeValue(_StochasticValueType):
"""Ask the StochasticTensor for n samples and reshape the result.
Sampling from a StochasticTensor increases the rank of the value by 1
(because each sample represents a new outer dimension).
This ValueType requests `n` samples from StochasticTensors run within its
context that the outer two dimensions are reshaped to intermix the samples
with the outermost (usually batch) dimension.
Example:
```python
# mu and sigma are both shaped (2, 3)
mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]]
sigma = tf.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]])
with sg.value_type(sg.SampleAndReshapeValue(n=2)):
mu = tf.zeros((2,3))
sigma = tf.ones((2, 3))
with sg.value_type(sg.SampleValue(4)):
st = sg.StochasticTensor(
distributions.Normal, mu=mu, sigma=sigma)
# sample(2) creates a (2, 2, 3) tensor, and the two outermost dimensions
# are reshaped into one: the final value is a (4, 3) tensor.
st_value = st.value()
assertEqual(st_value.get_shape(), (4, 3))
st_value_val = sess.run([st_value])[0] # or e.g. run([tf.identity(st)])[0]
assertEqual(st_value_val.shape, (4, 3))
distributions.Normal, mu=mu, sigma=sigma)
# draws 4 samples each with shape (2, 3) and concatenates
assertEqual(st.value().get_shape(), (4, 2, 3))
```
"""

def __init__(self, n=1, stop_gradient=False):
"""Sample `n` times and reshape the outer 2 axes so rank does not change.
def __init__(self, shape=(), stop_gradient=False):
"""Sample according to shape.
For the given StochasticTensor `st` using this value type,
the shape of `st.value()` will match that of
`st.distribution.sample(shape)`.
Args:
n: A python integer or int32 tensor. The number of samples to take.
shape: A shape tuple or int32 tensor. The sample shape.
Default is a scalar: take one sample and do not change the size.
stop_gradient: If `True`, StochasticTensors' values are wrapped in
`stop_gradient`, to avoid backpropagation through.
"""
self._n = n
self._shape = shape
self._stop_gradient = stop_gradient

@property
def n(self):
return self._n
def shape(self):
return self._shape

@property
def stop_gradient(self):
Expand Down Expand Up @@ -267,7 +228,7 @@ def value_type(dist_value_type):
in a `stop_gradients` call to disable any possible backpropagation.
Args:
dist_value_type: An instance of `MeanValue`, `SampleAndReshapeValue`, or
dist_value_type: An instance of `MeanValue`, `SampleValue`, or
any other stochastic value type.
Yields:
Expand Down Expand Up @@ -317,7 +278,7 @@ def __init__(self,
`StochasticTensor` is backed by the `dist` distribution and its `value`
method will return the same value each time it is called. What `value` is
returned is controlled by the `dist_value_type` (defaults to
`SampleAndReshapeValue`).
`SampleValue`).
Some distributions' sample functions are not differentiable (e.g. a sample
from a discrete distribution like a Bernoulli) and so to differentiate
Expand Down Expand Up @@ -356,7 +317,7 @@ def __init__(self,
try:
self._value_type = get_current_value_type()
except NoValueTypeSetError:
self._value_type = SampleAndReshapeValue()
self._value_type = SampleValue()
else:
# We want to enforce a value type here, but use the value_type()
# context manager to enforce some error checking.
Expand Down Expand Up @@ -388,26 +349,7 @@ def _create_value(self):
if isinstance(self._value_type, MeanValue):
value_tensor = self._dist.mean()
elif isinstance(self._value_type, SampleValue):
value_tensor = self._dist.sample(self._value_type.n)
elif isinstance(self._value_type, SampleAndReshapeValue):
if self._value_type.n == 1:
value_tensor = self._dist.sample()
else:
samples = self._dist.sample(self._value_type.n)
samples_shape = array_ops.shape(samples)
samples_static_shape = samples.get_shape()
new_batch_size = samples_shape[0] * samples_shape[1]
value_tensor = array_ops.reshape(
samples, array_ops.concat(0, ([new_batch_size], samples_shape[2:])))
if samples_static_shape.ndims is not None:
# Update the static shape for shape inference purposes
shape_list = samples_static_shape.as_list()
new_shape = tensor_shape.vector(
shape_list[0] * shape_list[1]
if shape_list[0] is not None and shape_list[1] is not None
else None)
new_shape = new_shape.concatenate(samples_static_shape[2:])
value_tensor.set_shape(new_shape)
value_tensor = self._dist.sample(self._value_type.shape)
else:
raise TypeError(
"Unrecognized Distribution Value Type: %s", self._value_type)
Expand Down Expand Up @@ -462,7 +404,6 @@ def loss(self, final_loss, name="Loss"):
with ops.name_scope(self.name, values=[final_loss]):
with ops.name_scope(name):
if (self._value_type.stop_gradient or
isinstance(self._value_type, SampleAndReshapeValue) or
isinstance(self._value_type, SampleValue)):
return self._loss_fn(self, self._value, final_loss)
elif isinstance(self._value_type, MeanValue):
Expand Down Expand Up @@ -530,7 +471,6 @@ def loss(self, final_loss, name=None):
"ObservedStochasticTensor",
"MeanValue",
"SampleValue",
"SampleAndReshapeValue",
"value_type",
"get_current_value_type",
]

0 comments on commit c66f878

Please sign in to comment.