Commit c63b9e2

another attempt at the MOG prior

1 parent 93d183f

2 files changed: +14 -30 lines

bayesian_rnn.py (+9 -13)
@@ -5,7 +5,7 @@
 from tensorflow.contrib.rnn import static_rnn, LSTMStateTuple
 
 from stochastic_variables import get_random_normal_variable, ExternallyParameterisedLSTM
-from stochastic_variables import log_gaussian_mixture_sample_probabilities, log_gaussian_sample_probabilities
+from stochastic_variables import gaussian_mixture_nll
 import logging
 
 logger = logging.getLogger(__name__)
@@ -145,21 +145,17 @@ def build_rnn(self):
                                  [softmax_b, softmax_b_mean, softmax_b_std]]:
 
             # TODO(Mark): get this to work with the MOG prior using sampling.
-            # bernoulli_samples = tf.floor(0.8 + tf.random_uniform(tf.shape(weight), minval=0.0, maxval=1.0))
-            # mean1 = mean2 = tf.zeros_like(mean)
-            # # Very pointy one:
-            # std1 = 0.0009 * tf.ones_like(std)
-            # # Flatter one:
-            # std2 = 0.15 * tf.ones_like(std)
-            # phi_log_probs = log_gaussian_sample_probabilities(weight, mean, std)
-            # phi_mixture_log_probs = \
-            #     log_gaussian_mixture_sample_probabilities(weight, bernoulli_samples, mean1, mean2, std1, std2)
-            # kl = tf.exp(phi_log_probs) * (phi_log_probs - phi_mixture_log_probs)
-            # phi_kl += kl
+            mean1 = mean2 = tf.zeros_like(mean)
+            # Very pointy one:
+            std1 = 0.0009 * tf.ones_like(std)
+            # Flatter one:
+            std2 = 0.15 * tf.ones_like(std)
+            phi_mixture_nll = gaussian_mixture_nll(weight, [0.6, 0.4], mean1, mean2, std1, std2)
+            phi_kl += phi_mixture_nll
 
             # This is different from the paper - just using a univariate gaussian
             # prior so that the KL has a closed form.
-            phi_kl += self.compute_kl_divergence((mean, std), (tf.zeros_like(mean), tf.ones_like(std) * 0.01))
+            # phi_kl += self.compute_kl_divergence((mean, std), (tf.zeros_like(mean), tf.ones_like(std) * 0.01))
 
             tf.summary.scalar("phi_kl", phi_kl)
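For intuition: the term this change adds to phi_kl is the elementwise negative log density of each sampled weight under a zero-mean, two-component gaussian scale mixture, i.e. only the -log p(w) (cross-entropy) part of a Monte Carlo KL estimate; the closed-form gaussian KL it comments out also included the log q(w) side. Below is a minimal standalone NumPy sketch of the same quantity. The name mog_prior_nll is illustrative, the 0.6/0.4 mixing weights and 0.0009/0.15 scales come from the diff, and, as in gaussian_mixture_nll, the scale parameters enter the formula as variances:

    import numpy as np

    def mog_prior_nll(weights, mixing_weights=(0.6, 0.4), var1=0.0009, var2=0.15):
        # Density of each weight under the two zero-mean components.
        # Following the diff, the scale arguments are used directly as
        # variances in both the normaliser and the exponent.
        gaussian1 = np.exp(-np.square(weights) / (2.0 * var1)) / np.sqrt(2.0 * np.pi * var1)
        gaussian2 = np.exp(-np.square(weights) / (2.0 * var2)) / np.sqrt(2.0 * np.pi * var2)
        mixture = mixing_weights[0] * gaussian1 + mixing_weights[1] * gaussian2
        return -np.log(mixture)

    # One weight tensor sampled from the variational posterior:
    weights = 0.05 * np.random.randn(4, 4)
    print(mog_prior_nll(weights))  # per-weight NLL; the diff accumulates this into phi_kl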

stochastic_variables.py (+5 -17)
@@ -7,29 +7,17 @@
 from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _checked_scope
 
 
-def log_gaussian_sample_probabilities(samples, mean, std):
+def gaussian_mixture_nll(samples, mixing_weights, mean1, mean2, std1, std2):
     """
-    Computes the log probability that the samples were drawn from a gaussian distribution
-    with the given mean and standard deviation.
-    """
-    pi_sigma = - 0.5 * tf.log(2.0 * std * math.pi)
-    mean_shift = tf.square(samples - mean) / (2.0 * std)
-
-    return pi_sigma - mean_shift
-
-
-def log_gaussian_mixture_sample_probabilities(samples, bernouli_samples, mean1, mean2, std1, std2):
-    """
-    Computes the log probability that the samples were drawn from a mixture of two gaussian distributions
-    with the given means and standard deviations, along with precomputed bernouli samples to compute the
-    mixture.
+    Computes the negative log likelihood of the samples under a mixture of two gaussian
+    distributions with the given means, standard deviations and mixing weights.
     """
     gaussian1 = (1.0 / tf.sqrt(2.0 * std1 * math.pi)) * tf.exp(-tf.square(samples - mean1) / (2.0 * std1))
     gaussian2 = (1.0 / tf.sqrt(2.0 * std2 * math.pi)) * tf.exp(-tf.square(samples - mean2) / (2.0 * std2))
 
-    mixture = bernouli_samples * gaussian1 + (1.0 - bernouli_samples) * gaussian2
+    mixture = (mixing_weights[0] * gaussian1) + (mixing_weights[1] * gaussian2)
 
-    return tf.log(mixture)
+    return -tf.log(mixture)
 
 
 def get_random_normal_variable(name, mean, standard_dev, shape, dtype):
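One caveat with the new implementation: tf.log(mixture) underflows to -inf whenever a sample is unlikely under both components, which is easy to hit given the very pointy first component. Below is a sketch of the same log density computed via log-sum-exp, keeping the diff's convention of treating the scale arguments as variances; the function name is illustrative, not part of the repository:

    import math
    import tensorflow as tf

    def gaussian_mixture_log_density(samples, mixing_weights, mean1, mean2, var1, var2):
        # log N(x; mean, var) for each component, computed entirely in log space.
        log_g1 = -0.5 * tf.log(2.0 * math.pi * var1) - tf.square(samples - mean1) / (2.0 * var1)
        log_g2 = -0.5 * tf.log(2.0 * math.pi * var2) - tf.square(samples - mean2) / (2.0 * var2)
        # log(w1 * g1 + w2 * g2) via log-sum-exp, which stays finite even when
        # both densities underflow in linear space.
        stacked = tf.stack([math.log(mixing_weights[0]) + log_g1,
                            math.log(mixing_weights[1]) + log_g2], axis=0)
        return tf.reduce_logsumexp(stacked, axis=0)

    # The prior term used above would then just be the negation:
    # phi_mixture_nll = -gaussian_mixture_log_density(weight, [0.6, 0.4], mean1, mean2, std1, std2)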
