Skip to content

Commit

Permalink
fix pep8 from #808
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Jan 8, 2018
1 parent 7648572 commit 4abda1a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
14 changes: 7 additions & 7 deletions edward/inferences/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def build_update(self):
# Update Empirical random variables.
assign_ops = []
for z_unconstrained, qz_unconstrained in six.iteritems(
self.latent_vars_unconstrained):
self.latent_vars_unconstrained):
variable = qz_unconstrained.get_variables()[0]
assign_ops.append(tf.scatter_update(
variable, self.t, sample[z_unconstrained]))
Expand All @@ -139,7 +139,7 @@ def build_update(self):

def _log_joint_unconstrained(self, z_sample):
"""
Given a sample in unconstrained latent space, transform it back into
Given a sample in unconstrained latent space, transform it back into
the original space, and compute the log joint density with appropriate
Jacobian correction.
"""
Expand All @@ -151,17 +151,17 @@ def _log_joint_unconstrained(self, z_sample):
z_sample_transformed = {}
log_det_jacobian = 0.0
for z_unconstrained, qz_unconstrained in z_sample.items():
z = (unconstrained_to_z[z_unconstrained]
if z_unconstrained in unconstrained_to_z
z = (unconstrained_to_z[z_unconstrained]
if z_unconstrained in unconstrained_to_z
else z_unconstrained)

try:
bij = self.transformations[z].bijector
z_sample_transformed[z] = bij.inverse(qz_unconstrained)
log_det_jacobian += tf.reduce_sum(
bij.inverse_log_det_jacobian(qz_unconstrained))
except: # if z not in self.transformations,
# or is not a TransformedDist w/ bijector
bij.inverse_log_det_jacobian(qz_unconstrained))
except: # if z not in self.transformations,
# or is not a TransformedDist w/ bijector
z_sample_transformed[z] = qz_unconstrained

return self._log_joint(z_sample_transformed) + log_det_jacobian
Expand Down
17 changes: 10 additions & 7 deletions edward/inferences/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from tensorflow.contrib.distributions import bijectors


@six.add_metaclass(abc.ABCMeta)
class Inference(object):
"""Abstract base class for inference. All inference algorithms in
Expand Down Expand Up @@ -222,11 +223,13 @@ def initialize(self, n_iter=1000, n_print=None, scale=None,
self.transformations = {}
if auto_transform:
latent_vars = self.latent_vars.copy()
self.latent_vars = {} # maps original latent vars to constrained Q's
self.latent_vars_unconstrained = {} # maps unconstrained vars to unconstrained Q's
# latent_vars maps original latent vars to constrained Q's.
# latent_vars_unconstrained maps unconstrained vars to unconstrained Q's.
self.latent_vars = {}
self.latent_vars_unconstrained = {}
for z, qz in six.iteritems(latent_vars):
if hasattr(z, 'support') and hasattr(qz, 'support') and \
z.support != qz.support and qz.support != 'point':
z.support != qz.support and qz.support != 'point':

# transform z to an unconstrained space
z_unconstrained = transform(z)
Expand All @@ -243,12 +246,12 @@ def initialize(self, n_iter=1000, n_print=None, scale=None,
# back into the original constrained space
if z_unconstrained != z:
qz_constrained = transform(
qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))
qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))

try: # attempt to pushforward the params of Empirical distributions
try: # attempt to pushforward the params of Empirical distributions
qz_constrained.params = z_unconstrained.bijector.inverse(
qz_unconstrained.params)
except: # qz_unconstrained is not an Empirical distribution
qz_unconstrained.params)
except: # qz_unconstrained is not an Empirical distribution
pass

else:
Expand Down

0 comments on commit 4abda1a

Please sign in to comment.