diff --git a/edward/inferences/inference.py b/edward/inferences/inference.py
index 72ae19d9a..b02488627 100644
--- a/edward/inferences/inference.py
+++ b/edward/inferences/inference.py
@@ -24,12 +24,15 @@ class Inference(object):

   Attributes
   ----------
-  latent_vars : dict of RandomVariable to RandomVariable
-    Collection of random variables to perform inference on. Each
-    random variable is binded to another random variable; the latter
-    will infer the former conditional on data.
+  latent_vars : dict
+    Collection of latent variables (of type ``RandomVariable`` or
+    ``tf.Tensor``) to perform inference on. Each random variable is
+    bound to another random variable; the latter will infer the
+    former conditional on data.
   data : dict
-    Data dictionary whose values may vary at each session run.
+    Data dictionary which binds observed variables (of type
+    ``RandomVariable`` or ``tf.Tensor``) to their realizations (of
+    type ``tf.Tensor``).
   model_wrapper : ed.Model or None
     An optional wrapper for the probability model. If specified, the
     random variables in ``latent_vars``' dictionary keys are strings
@@ -40,10 +43,11 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):

     Parameters
     ----------
-    latent_vars : dict of RandomVariable to RandomVariable, optional
-      Collection of random variables to perform inference on. Each
-      random variable is binded to another random variable; the latter
-      will infer the former conditional on data.
+    latent_vars : dict, optional
+      Collection of latent variables (of type ``RandomVariable`` or
+      ``tf.Tensor``) to perform inference on. Each random variable is
+      bound to another random variable; the latter will infer the
+      former conditional on data.
     data : dict, optional
       Data dictionary which binds observed variables (of type
       ``RandomVariable`` or ``tf.Tensor``) to their realizations (of
@@ -92,10 +96,9 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
       raise TypeError()

     for key, value in six.iteritems(latent_vars):
-      if isinstance(value, RandomVariable):
-        if isinstance(key, RandomVariable):
-          if not key.value().get_shape().is_compatible_with(
-              value.value().get_shape()):
+      if isinstance(value, RandomVariable) or isinstance(value, tf.Tensor):
+        if isinstance(key, RandomVariable) or isinstance(key, tf.Tensor):
+          if not key.get_shape().is_compatible_with(value.get_shape()):
             raise TypeError("Latent variable bindings do not have same shape.")
         elif not isinstance(key, str):
           raise TypeError("Latent variable key has an invalid type.")
@@ -119,16 +122,22 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
     else:
       self.data = {}
       for key, value in six.iteritems(data):
-        if isinstance(key, RandomVariable):
+        if isinstance(key, RandomVariable) or \
+            (isinstance(key, tf.Tensor) and "Placeholder" not in key.op.type):
           if isinstance(value, tf.Tensor):
-            if not key.value().get_shape().is_compatible_with(
-                value.get_shape()):
+            if not key.get_shape().is_compatible_with(value.get_shape()):
               raise TypeError("Observed variable bindings do not have same "
                               "shape.")

            self.data[key] = tf.cast(value, tf.float32)
+          elif isinstance(value, RandomVariable):
+            if not key.get_shape().is_compatible_with(value.get_shape()):
+              raise TypeError("Observed variable bindings do not have same "
+                              "shape.")
+
+            self.data[key] = value
          elif isinstance(value, np.ndarray):
-            if not key.value().get_shape().is_compatible_with(value.shape):
+            if not key.get_shape().is_compatible_with(value.shape):
               raise TypeError("Observed variable bindings do not have same "
                               "shape.")
@@ -144,13 +153,6 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
             var = tf.Variable(ph, trainable=False, collections=[])
             self.data[key] = var
             sess.run(var.initializer, {ph: value})
-          elif isinstance(value, RandomVariable):
-            if not key.value().get_shape().is_compatible_with(
-                value.value().get_shape()):
-              raise TypeError("Observed variable bindings do not have same "
-                              "shape.")
-
-            self.data[key] = value
           elif isinstance(value, np.number):
             if np.issubdtype(value.dtype, np.float):
               ph_type = tf.float32
@@ -177,6 +179,12 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
             sess.run(var.initializer, {ph: int(value)})
           else:
             raise TypeError("Data value has an invalid type.")
+        elif isinstance(key, tf.Tensor):
+          if isinstance(value, RandomVariable):
+            raise TypeError("Data placeholder cannot be bound to a "
+                            "RandomVariable.")
+
+          self.data[key] = value
         elif isinstance(key, str):
           if isinstance(value, tf.Tensor):
             self.data[key] = tf.cast(value, tf.float32)
@@ -190,12 +198,6 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
         elif (have_theano and
               isinstance(key, theano.tensor.sharedvar.TensorSharedVariable)):
           self.data[key] = value
-        elif isinstance(key, tf.Tensor):
-          if isinstance(value, RandomVariable):
-            raise TypeError("Data placeholder cannot be bound to a "
-                            "RandomVariable.")
-
-          self.data[key] = value
         else:
           raise TypeError("Data key has an invalid type.")

diff --git a/edward/inferences/map.py b/edward/inferences/map.py
index 02436e592..52850c540 100644
--- a/edward/inferences/map.py
+++ b/edward/inferences/map.py
@@ -55,7 +55,7 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):
       list, each random variable will be implictly optimized using a
       ``PointMass`` random variable that is defined internally (with
       unconstrained support). If dictionary, each
-      random variable must be a ``PointMass`` random variable.
+      value in the dictionary must be a ``PointMass`` random variable.

     Examples
     --------
diff --git a/edward/inferences/monte_carlo.py b/edward/inferences/monte_carlo.py
index cbc217e4b..6baf2959b 100644
--- a/edward/inferences/monte_carlo.py
+++ b/edward/inferences/monte_carlo.py
@@ -19,13 +19,13 @@ def __init__(self, latent_vars=None, data=None, model_wrapper=None):

     Parameters
     ----------
-    latent_vars : list of RandomVariable or
-                  dict of RandomVariable to RandomVariable
-      Collection of random variables to perform inference on. If
-      list, each random variable will be implictly approximated
-      using a ``Empirical`` random variable that is defined
-      internally (with unconstrained support). If dictionary, each
-      random variable must be a ``Empirical`` random variable.
+    latent_vars : list or dict, optional
+      Collection of random variables (of type ``RandomVariable`` or
+      ``tf.Tensor``) to perform inference on. If list, each random
+      variable will be approximated using an ``Empirical`` random
+      variable that is defined internally (with unconstrained
+      support). If dictionary, each value in the dictionary must be an
+      ``Empirical`` random variable.
     data : dict, optional
       Data dictionary which binds observed variables (of type
       ``RandomVariable`` or ``tf.Tensor``) to their realizations (of
diff --git a/examples/gan_wasserstein.py b/examples/gan_wasserstein.py
index 89ceabb3d..a20f2b808 100644
--- a/examples/gan_wasserstein.py
+++ b/examples/gan_wasserstein.py
@@ -77,7 +77,7 @@ def plot(samples):
     data={x: x_ph}, discriminator=discriminative_network)
 inference.initialize(
     optimizer=optimizer, optimizer_d=optimizer,
-    n_iter=15000 * 6, n_print=1000 * 6)
+    n_iter=15000, n_print=1000)

 sess = ed.get_session()
 tf.global_variables_initializer().run()
@@ -85,7 +85,7 @@ def plot(samples):
 idx = np.random.randint(M, size=16)
 i = 0
 for t in range(inference.n_iter):
-  if (t * 6) % inference.n_print == 0:
+  if t % inference.n_print == 0:
     samples = sess.run(x)
     samples = samples[idx, ]

@@ -102,4 +102,5 @@ def plot(samples):
   info_dict = inference.update(feed_dict={x_ph: x_batch}, variables="Gen")
   # note: not printing discriminative objective; ``info_dict`` above
   # does not store it since updating only "Gen"
+  info_dict['t'] = info_dict['t'] // 6  # say set of 6 updates is 1 iteration
   inference.print_progress(info_dict)
diff --git a/tests/test-inferences/test_inference.py b/tests/test-inferences/test_inference.py
index 4aefe1c24..0826a068d 100644
--- a/tests/test-inferences/test_inference.py
+++ b/tests/test-inferences/test_inference.py
@@ -27,9 +27,9 @@ def test_latent_vars(self):
       qmu_misshape = Normal(mu=tf.constant([0.0]), sigma=tf.constant([1.0]))

       ed.Inference({mu: qmu})
+      ed.Inference({mu: tf.constant(0.0)})
+      ed.Inference({tf.constant(0.0): qmu})
       self.assertRaises(TypeError, ed.Inference, {mu: '5'})
-      self.assertRaises(TypeError, ed.Inference, {mu: tf.constant(0.0)})
-      self.assertRaises(TypeError, ed.Inference, {tf.constant(0.0): qmu})
       self.assertRaises(TypeError, ed.Inference, {mu: qmu_misshape})

   def test_data(self):
@@ -49,6 +49,7 @@ def test_data(self):
       ed.Inference(data={x: False})  # converted to `int`
       ed.Inference(data={x: x_ph})
       ed.Inference(data={x: qx})
+      ed.Inference(data={2.0 * x: tf.constant(0.0)})
       self.assertRaises(TypeError, ed.Inference, data={5: tf.constant(0.0)})
       self.assertRaises(TypeError, ed.Inference, data={x: tf.zeros(5)})
       self.assertRaises(TypeError, ed.Inference, data={x_ph: x})
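
For reference, below is a minimal sketch of the bindings this change enables, mirroring the new and existing test cases. It assumes the Edward 1.x / TensorFlow 1.x API used throughout this diff; the placeholder construction for ``x_ph`` is not shown in the tests and is filled in here only for illustration.

import tensorflow as tf
import edward as ed
from edward.models import Normal

mu = Normal(mu=0.0, sigma=1.0)                             # latent variable
qmu = Normal(mu=tf.Variable(0.0), sigma=tf.constant(1.0))  # approximating variable
x = Normal(mu=mu, sigma=1.0)                               # observed variable

# Latent variables may now be bound to plain tf.Tensor values, not only to
# RandomVariables; shapes must still be compatible.
ed.Inference({mu: tf.constant(0.0)})
ed.Inference({tf.constant(0.0): qmu})

# A non-placeholder tensor, e.g. a transformation of a random variable, may
# now serve as a data key bound to its observed value.
ed.Inference(data={2.0 * x: tf.constant(0.0)})

# A random variable may be bound to a placeholder whose value is fed at run
# time; binding a placeholder key to a RandomVariable still raises TypeError.
x_ph = tf.placeholder(tf.float32, [])
ed.Inference(data={x: x_ph})

In practice a concrete subclass such as ``MAP`` or ``MonteCarlo`` would be used; ``ed.Inference`` is instantiated directly here, as in the tests, only to exercise the binding checks.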