Skip to content

Commit

Permalink
fix pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Mar 20, 2017
1 parent 693b0b6 commit a01b7e1
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions edward/models/dirichlet_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,29 +155,29 @@ def _sample_n_body(self, k, bools, theta, draws):
n, batch_shape, event_shape, rank = self._temp_scope
k += 1

# If necessary, add a new persistent state to theta.
def fn():
theta_k = self._base_cls(
*self._base_args, **self._base_kwargs).sample(batch_shape)
return tf.concat([theta, tf.expand_dims(theta_k, 0)], 0)

# If necessary, add a new persistent state to theta.
theta = tf.cond(tf.shape(theta)[0] - 1 >= k, lambda: theta, fn)
theta_k = tf.gather(theta, k)

# Assign True samples to the new theta_k.
if len(bools.get_shape()) <= 1:
bools_broadcast = bools
bools_tile = bools
else:
# ``tf.where`` only index subsets when ``bools`` is at most a
# vector. In general, ``bools`` has shape (n, batch_shape).
# Therefore we tile ``bools`` to be of shape
# (n, batch_shape, event_shape) in order to index per-element.
bools_broadcast = tf.tile(tf.reshape(
bools_tile = tf.tile(tf.reshape(
bools, [n] + batch_shape + [1] * len(event_shape)),
[1] + [1] * len(batch_shape) + event_shape)

# Assign True samples to the new theta_k.
theta_k_broadcast = tf.tile(tf.expand_dims(theta_k, 0), [n] + [1] * (rank - 1))
draws = tf.where(bools_broadcast, theta_k_broadcast, draws)
theta_k_tile = tf.tile(tf.expand_dims(theta_k, 0), [n] + [1] * (rank - 1))
draws = tf.where(bools_tile, theta_k_tile, draws)

# Draw new stick probability, then flip coin.
beta_k = Beta(a=tf.ones_like(self.alpha), b=self.alpha).sample(n)
Expand Down

0 comments on commit a01b7e1

Please sign in to comment.