
Commit

Let Inference work with tf.Tensor latent variables and observed variables (#488)

* let Inference work with tf.Tensor latent_vars and data

* re-order control flow logic

* update docstring of inference subclasses

* update tests
dustinvtran committed Feb 28, 2017
1 parent 06c4a57 commit f20c8f8
Showing 5 changed files with 46 additions and 42 deletions.
62 changes: 32 additions & 30 deletions edward/inferences/inference.py
@@ -24,12 +24,15 @@ class Inference(object):
Attributes
----------
latent_vars : dict of RandomVariable to RandomVariable
Collection of random variables to perform inference on. Each
random variable is bound to another random variable; the latter
will infer the former conditional on data.
latent_vars : dict
Collection of latent variables (of type ``RandomVariable`` or
``tf.Tensor``) to perform inference on. Each random variable is
bound to another random variable; the latter will infer the
former conditional on data.
data : dict
Data dictionary whose values may vary at each session run.
Data dictionary which binds observed variables (of type
``RandomVariable`` or ``tf.Tensor``) to their realizations (of
type ``tf.Tensor``).
model_wrapper : ed.Model or None
An optional wrapper for the probability model. If specified, the
random variables in ``latent_vars``' dictionary keys are strings
@@ -40,10 +43,11 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
Parameters
----------
latent_vars : dict of RandomVariable to RandomVariable, optional
Collection of random variables to perform inference on. Each
random variable is bound to another random variable; the latter
will infer the former conditional on data.
latent_vars : dict, optional
Collection of latent variables (of type ``RandomVariable`` or
``tf.Tensor``) to perform inference on. Each random variable is
bound to another random variable; the latter will infer the
former conditional on data.
data : dict, optional
Data dictionary which binds observed variables (of type
``RandomVariable`` or ``tf.Tensor``) to their realizations (of
@@ -92,10 +96,9 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
raise TypeError()

for key, value in six.iteritems(latent_vars):
if isinstance(value, RandomVariable):
if isinstance(key, RandomVariable):
if not key.value().get_shape().is_compatible_with(
value.value().get_shape()):
if isinstance(value, RandomVariable) or isinstance(value, tf.Tensor):
if isinstance(key, RandomVariable) or isinstance(key, tf.Tensor):
if not key.get_shape().is_compatible_with(value.get_shape()):
raise TypeError("Latent variable bindings do not have same shape.")
elif not isinstance(key, str):
raise TypeError("Latent variable key has an invalid type.")
@@ -119,16 +122,22 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
else:
self.data = {}
for key, value in six.iteritems(data):
if isinstance(key, RandomVariable):
if isinstance(key, RandomVariable) or \
(isinstance(key, tf.Tensor) and "Placeholder" not in key.op.type):
if isinstance(value, tf.Tensor):
if not key.value().get_shape().is_compatible_with(
value.get_shape()):
if not key.get_shape().is_compatible_with(value.get_shape()):
raise TypeError("Observed variable bindings do not have same "
"shape.")

self.data[key] = tf.cast(value, tf.float32)
elif isinstance(value, RandomVariable):
if not key.get_shape().is_compatible_with(value.get_shape()):
raise TypeError("Observed variable bindings do not have same "
"shape.")

self.data[key] = value
elif isinstance(value, np.ndarray):
if not key.value().get_shape().is_compatible_with(value.shape):
if not key.get_shape().is_compatible_with(value.shape):
raise TypeError("Observed variable bindings do not have same "
"shape.")

@@ -144,13 +153,6 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
var = tf.Variable(ph, trainable=False, collections=[])
self.data[key] = var
sess.run(var.initializer, {ph: value})
elif isinstance(value, RandomVariable):
if not key.value().get_shape().is_compatible_with(
value.value().get_shape()):
raise TypeError("Observed variable bindings do not have same "
"shape.")

self.data[key] = value
elif isinstance(value, np.number):
if np.issubdtype(value.dtype, np.float):
ph_type = tf.float32
@@ -177,6 +179,12 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
sess.run(var.initializer, {ph: int(value)})
else:
raise TypeError("Data value has an invalid type.")
elif isinstance(key, tf.Tensor):
if isinstance(value, RandomVariable):
raise TypeError("Data placeholder cannot be bound to a "
"RandomVariable.")

self.data[key] = value
elif isinstance(key, str):
if isinstance(value, tf.Tensor):
self.data[key] = tf.cast(value, tf.float32)
Expand All @@ -190,12 +198,6 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
elif (have_theano and
isinstance(key, theano.tensor.sharedvar.TensorSharedVariable)):
self.data[key] = value
elif isinstance(key, tf.Tensor):
if isinstance(value, RandomVariable):
raise TypeError("Data placeholder cannot be bound to a "
"RandomVariable.")

self.data[key] = value
else:
raise TypeError("Data key has an invalid type.")

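For orientation, here is a minimal sketch (not part of the commit) of the bindings the revised ``Inference`` constructor now accepts; it mirrors the updated tests at the end of this diff, and the variable names are illustrative.

import tensorflow as tf
import edward as ed
from edward.models import Normal

mu = Normal(mu=tf.constant(0.0), sigma=tf.constant(1.0))
qmu = Normal(mu=tf.Variable(0.0), sigma=tf.constant(1.0))
x = Normal(mu=mu, sigma=tf.constant(1.0))

# Latent variables can now be bound to plain tf.Tensors, not only to
# other RandomVariables, as long as their shapes are compatible.
ed.Inference({mu: tf.constant(0.0)}, data={x: tf.constant(0.0)})

# Non-placeholder tf.Tensors can also serve as observed-variable keys,
# bound to their tf.Tensor realizations.
ed.Inference({mu: qmu}, data={2.0 * x: tf.constant(0.0)})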
2 changes: 1 addition & 1 deletion edward/inferences/map.py
@@ -55,7 +55,7 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
list, each random variable will be implicitly optimized
using a ``PointMass`` random variable that is defined
internally (with unconstrained support). If dictionary, each
random variable must be a ``PointMass`` random variable.
value in the dictionary must be a ``PointMass`` random variable.
Examples
--------
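A hedged sketch of the two calling conventions described in the docstring above; the model and variable names are hypothetical, not part of the commit.

import tensorflow as tf
import edward as ed
from edward.models import Normal, PointMass

mu = Normal(mu=tf.constant(0.0), sigma=tf.constant(1.0))
x = Normal(mu=mu, sigma=tf.constant(1.0))
x_data = tf.constant(0.5)

# List form: a PointMass approximation is constructed internally.
inference = ed.MAP([mu], data={x: x_data})

# Dict form: each value must itself be a PointMass random variable.
qmu = PointMass(params=tf.Variable(0.0))
inference = ed.MAP({mu: qmu}, data={x: x_data})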
14 changes: 7 additions & 7 deletions edward/inferences/monte_carlo.py
@@ -19,13 +19,13 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
Parameters
----------
latent_vars : list of RandomVariable or
dict of RandomVariable to RandomVariable
Collection of random variables to perform inference on. If
list, each random variable will be implicitly approximated
using an ``Empirical`` random variable that is defined
internally (with unconstrained support). If dictionary, each
random variable must be an ``Empirical`` random variable.
latent_vars : list or dict, optional
Collection of random variables (of type ``RandomVariable`` or
``tf.Tensor``) to perform inference on. If list, each random
variable will be approximated using an ``Empirical`` random
variable that is defined internally (with unconstrained
support). If dictionary, each value in the dictionary must be
an ``Empirical`` random variable.
data : dict, optional
Data dictionary which binds observed variables (of type
``RandomVariable`` or ``tf.Tensor``) to their realizations (of
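A hedged sketch of the list and dict forms described in the docstring above, using HMC as one concrete MonteCarlo subclass; the names and the sample count are illustrative, not taken from the commit.

import tensorflow as tf
import edward as ed
from edward.models import Empirical, Normal

mu = Normal(mu=tf.constant(0.0), sigma=tf.constant(1.0))
x = Normal(mu=mu, sigma=tf.constant(1.0))
x_data = tf.constant(0.5)

# List form: an Empirical approximation is constructed internally.
inference = ed.HMC([mu], data={x: x_data})

# Dict form: each value must be an Empirical random variable whose
# first dimension is the number of samples to store.
qmu = Empirical(params=tf.Variable(tf.zeros(500)))
inference = ed.HMC({mu: qmu}, data={x: x_data})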
5 changes: 3 additions & 2 deletions examples/gan_wasserstein.py
@@ -77,15 +77,15 @@ def plot(samples):
data={x: x_ph}, discriminator=discriminative_network)
inference.initialize(
optimizer=optimizer, optimizer_d=optimizer,
n_iter=15000 * 6, n_print=1000 * 6)
n_iter=15000, n_print=1000)

sess = ed.get_session()
tf.global_variables_initializer().run()

idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
if (t * 6) % inference.n_print == 0:
if t % inference.n_print == 0:
samples = sess.run(x)
samples = samples[idx, ]

@@ -102,4 +102,5 @@ def plot(samples):
info_dict = inference.update(feed_dict={x_ph: x_batch}, variables="Gen")
# note: not printing discriminative objective; ``info_dict`` above
# does not store it since updating only "Gen"
info_dict['t'] = info_dict['t'] // 6  # count each set of 6 updates as one iteration
inference.print_progress(info_dict)
5 changes: 3 additions & 2 deletions tests/test-inferences/test_inference.py
@@ -27,9 +27,9 @@ def test_latent_vars(self):
qmu_misshape = Normal(mu=tf.constant([0.0]), sigma=tf.constant([1.0]))

ed.Inference({mu: qmu})
ed.Inference({mu: tf.constant(0.0)})
ed.Inference({tf.constant(0.0): qmu})
self.assertRaises(TypeError, ed.Inference, {mu: '5'})
self.assertRaises(TypeError, ed.Inference, {mu: tf.constant(0.0)})
self.assertRaises(TypeError, ed.Inference, {tf.constant(0.0): qmu})
self.assertRaises(TypeError, ed.Inference, {mu: qmu_misshape})

def test_data(self):
@@ -49,6 +49,7 @@ def test_data(self):
ed.Inference(data={x: False}) # converted to `int`
ed.Inference(data={x: x_ph})
ed.Inference(data={x: qx})
ed.Inference(data={2.0 * x: tf.constant(0.0)})
self.assertRaises(TypeError, ed.Inference, data={5: tf.constant(0.0)})
self.assertRaises(TypeError, ed.Inference, data={x: tf.zeros(5)})
self.assertRaises(TypeError, ed.Inference, data={x_ph: x})
