|
5 | 5 | from tensorflow.contrib.rnn import static_rnn, LSTMStateTuple
|
6 | 6 |
|
7 | 7 | from stochastic_variables import get_random_normal_variable, ExternallyParameterisedLSTM
|
8 |
| -from stochastic_variables import log_gaussian_mixture_sample_probabilities, log_gaussian_sample_probabilities |
| 8 | +from stochastic_variables import gaussian_mixture_nll |
9 | 9 | import logging
|
10 | 10 |
|
11 | 11 | logger = logging.getLogger(__name__)
|
@@ -145,21 +145,17 @@ def build_rnn(self):
|
145 | 145 | [softmax_b, softmax_b_mean, softmax_b_std]]:
|
146 | 146 |
|
147 | 147 | # TODO(Mark): get this to work with the MOG prior using sampling.
|
148 |
| - # bernoulli_samples = tf.floor(0.8 + tf.random_uniform(tf.shape(weight), minval=0.0, maxval=1.0)) |
149 |
| - # mean1 = mean2 = tf.zeros_like(mean) |
150 |
| - # # Very pointy one: |
151 |
| - # std1 = 0.0009 * tf.ones_like(std) |
152 |
| - # # Flatter one: |
153 |
| - # std2 = 0.15 * tf.ones_like(std) |
154 |
| - # phi_log_probs = log_gaussian_sample_probabilities(weight, mean, std) |
155 |
| - # phi_mixture_log_probs = \ |
156 |
| - # log_gaussian_mixture_sample_probabilities(weight, bernoulli_samples, mean1, mean2, std1, std2) |
157 |
| - # kl = tf.exp(phi_log_probs) * (phi_log_probs - phi_mixture_log_probs) |
158 |
| - # phi_kl += kl |
| 148 | + mean1 = mean2 = tf.zeros_like(mean) |
| 149 | + # Very pointy one: |
| 150 | + std1 = 0.0009 * tf.ones_like(std) |
| 151 | + # Flatter one: |
| 152 | + std2 = 0.15 * tf.ones_like(std) |
| 153 | + phi_mixture_nll = gaussian_mixture_nll(weight, [0.6, 0.4], mean1, mean2, std1, std2) |
| 154 | + phi_kl += phi_mixture_nll |
159 | 155 |
|
160 | 156 | # This is different from the paper - just using a univariate gaussian
|
161 | 157 | # prior so that the KL has a closed form.
|
162 |
| - phi_kl += self.compute_kl_divergence((mean, std), (tf.zeros_like(mean), tf.ones_like(std) * 0.01)) |
| 158 | + #phi_kl += self.compute_kl_divergence((mean, std), (tf.zeros_like(mean), tf.ones_like(std) * 0.01)) |
163 | 159 |
|
164 | 160 | tf.summary.scalar("phi_kl", phi_kl)
|
165 | 161 |
|
|
0 commit comments