Skip to content

Commit

Permalink
simplify placeholder if/else logic (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran authored Feb 28, 2017
1 parent 176e2c2 commit 06c4a57
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 18 deletions.
5 changes: 2 additions & 3 deletions edward/inferences/gan_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@ def update(self, feed_dict=None, variables=None):
feed_dict = {}

for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor):
if "Placeholder" in key.op.type:
feed_dict[key] = value
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
feed_dict[key] = value

sess = get_session()
if variables is None:
Expand Down
10 changes: 4 additions & 6 deletions edward/inferences/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,8 @@ def run(self, variables=None, use_coordinator=True, *args, **kwargs):
# Feed placeholders in case initialization depends on them.
feed_dict = {}
for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor):
if "Placeholder" in key.op.type:
feed_dict[key] = value
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
feed_dict[key] = value

init.run(feed_dict)

Expand Down Expand Up @@ -375,9 +374,8 @@ def update(self, feed_dict=None):
feed_dict = {}

for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor):
if "Placeholder" in key.op.type:
feed_dict[key] = value
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
feed_dict[key] = value

sess = get_session()
t = sess.run(self.increment_t)
Expand Down
5 changes: 2 additions & 3 deletions edward/inferences/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,8 @@ def update(self, feed_dict=None):
feed_dict = {}

for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor):
if "Placeholder" in key.op.type:
feed_dict[key] = value
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
feed_dict[key] = value

sess = get_session()
_, accept_rate = sess.run([self.train, self.n_accept_over_t], feed_dict)
Expand Down
5 changes: 2 additions & 3 deletions edward/inferences/variational_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,8 @@ def update(self, feed_dict=None):
feed_dict = {}

for key, value in six.iteritems(self.data):
if isinstance(key, tf.Tensor):
if "Placeholder" in key.op.type:
feed_dict[key] = value
if isinstance(key, tf.Tensor) and "Placeholder" in key.op.type:
feed_dict[key] = value

sess = get_session()
_, t, loss = sess.run([self.train, self.increment_t, self.loss], feed_dict)
Expand Down
6 changes: 3 additions & 3 deletions edward/util/random_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def copy(org_instance, dict_swap=None, scope="copied",
return graph.get_tensor_by_name(variables[org_instance.name].name)

# Do the same for placeholders. Determine via its op's type.
if isinstance(org_instance, tf.Tensor):
if "Placeholder" in org_instance.op.type:
return org_instance
if isinstance(org_instance, tf.Tensor) and \
"Placeholder" in org_instance.op.type:
return org_instance

if isinstance(org_instance, RandomVariable):
rv = org_instance
Expand Down

0 comments on commit 06c4a57

Please sign in to comment.