 from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_layers
 from tensor2tensor.models import transformer
+from tensor2tensor.utils import expert_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model

@@ -84,12 +85,37 @@ def decompress_step(source, c, hparams, first_relu, name):
   return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])


-def dvae(x, hparams, name):
+def top_k_softmax(x, k):
+  """Calculate softmax(x), select top-k and rescale to sum to 1."""
+  x = tf.nn.softmax(x)
+  top_x, _ = tf.nn.top_k(x, k=k+1)
+  min_top = tf.reduce_min(top_x, axis=-1, keep_dims=True)
+  x = tf.nn.relu((x - min_top) + 1e-12)
+  x /= tf.reduce_sum(x, axis=-1, keep_dims=True)
+  return x, tf.reduce_max(top_x, axis=-1)
+
+
+def top_k_experts(x, k, hparams):
+  x_shape = tf.shape(x)
+  x_flat = tf.reshape(x, [-1, x.get_shape().as_list()[-1]])
+  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+  gates, load = expert_utils.noisy_top_k_gating(
+      x_flat, hparams.v_size, is_training, k)
+  gates_shape = [x_shape[0], x_shape[1], x_shape[2], hparams.v_size]
+  gates = tf.reshape(gates, gates_shape)
+  load_loss = expert_utils.cv_squared(load)
+  return gates, load_loss
+
+
+def dvae(x, k, hparams, name):
   with tf.variable_scope(name):
     m = tf.layers.dense(x, hparams.v_size, name="mask")
-    m = tf.nn.softmax(m)
-    kl = - tf.reduce_max(m, axis=-1)
-    return m, tf.reduce_mean(kl)
+    if k is None:
+      m = tf.nn.softmax(m)
+      kl = - tf.reduce_max(m, axis=-1)
+    else:
+      m, kl = top_k_softmax(m, k)
+    return m, 1.0 - tf.reduce_mean(kl)


 def vae(x, hparams, name):
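Aside (not part of the commit): the hunk above makes dvae's discretization optionally sparse. With k set, top_k_softmax keeps only the k largest softmax probabilities and renormalizes them, and top_k_experts adds a noisy top-k gate (expert_utils.noisy_top_k_gating) with a cv_squared load-balancing penalty; dvae now also returns 1.0 - mean(kl) rather than mean(kl). A minimal NumPy sketch of the top-k renormalization, where top_k_softmax_np and the toy logits are illustrative stand-ins, not part of the library:

import numpy as np

def top_k_softmax_np(logits, k):
  # Softmax over the last axis.
  p = np.exp(logits - logits.max(-1, keepdims=True))
  p /= p.sum(-1, keepdims=True)
  # The (k+1)-th largest probability is the cutoff, as in the TF code above.
  top = np.sort(p, axis=-1)[..., -(k + 1):]
  min_top = top.min(-1, keepdims=True)
  # Shift and clip: entries at or below the cutoff collapse to (almost) zero.
  p = np.maximum(p - min_top + 1e-12, 0.0)
  p /= p.sum(-1, keepdims=True)
  return p, top.max(-1)

probs, _ = top_k_softmax_np(np.array([2.0, 1.0, 0.5, -1.0]), k=2)
print(probs)  # the two largest entries carry essentially all of the mass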
@@ -119,42 +145,59 @@ def compress(x, c, hparams, name):
   return cur


+def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin"):
+  if mode == "lin":
+    alpha_p = common_layers.inverse_lin_decay(steps) + 0.001
+  else:
+    alpha_p = common_layers.inverse_exp_decay(steps) + 0.001
+  alpha_p = alpha_p * (max_prob - min_prob) + min_prob
+  alpha = tf.random_uniform(tf.shape(x1))
+  alpha = tf.to_float(tf.less(alpha, alpha_p))
+  return alpha * x1 + (1.0 - alpha) * x2
+
+
 def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None):
   """Compress, then VAE."""
+  mix_k = 8
   with tf.variable_scope(compress_name, reuse=reuse):
     cur = compress(x, None, hparams, "compress")
     # Convolve and ReLu to get state.
     cur = common_layers.conv_block(
         cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
     # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
-    z, kl_loss = dvae(cur, hparams, name="dvae")
+    z, kl_loss = dvae(cur, None, hparams, name="dvae")
+    z1, kl_loss1 = top_k_experts(cur, mix_k, hparams)
     mu, log_sigma = None, None

+    # Mix expert-selection and flat selection.
+    alpha_p = common_layers.inverse_lin_decay(60000) + 0.001
+    z = alpha_p * z1 + (1 - alpha_p) * z
+    kl_loss += kl_loss1
+
   # Compress context.
   with tf.variable_scope(compress_name, reuse=reuse):
     compress_c = compress(c, None, hparams, "compress_context")
     c_z = tf.layers.dense(compress_c, hparams.v_size, name="mask_context")
   reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
       labels=z, logits=c_z)

+  # If not training, use the predicted z instead of the autoregressive one.
+  # if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
+  #   z = mix(c_z, z, 50000, max_prob=0.3, mode="exp")
+  #   z, _ = top_k_softmax(c_z, mix_k)
+
   with tf.variable_scope(decompress_name, reuse=reuse):
     # Decompress.
     z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")

     # Leak at the beginning to help train.
-    alpha_p = common_layers.inverse_lin_decay(30000) + 0.001
-    alpha = tf.random_uniform(tf.shape(cur))
-    alpha = tf.to_float(tf.less(alpha, alpha_p))
-    z = alpha * z + (1.0 - alpha) * cur
-
-    # TODO(lukaszkaiser): If not training, use the predicted z.
-    # is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+    z = mix(z, cur, 30000)

     for i in xrange(hparams.num_compress_steps):
       j = hparams.num_compress_steps - i - 1
       z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j)
       z = decompress_step(z, c, hparams, i > 0, "decompress_step_%d" % j)
-  return z, kl_loss + 0.001 * reconstruct_loss, mu, log_sigma
+  return z, kl_loss + 0.0001 * reconstruct_loss, mu, log_sigma


 def encode(x, x_space, hparams, name):
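Aside (not part of the commit): the old inline leak in vae_compress becomes the reusable mix helper, where each element comes from x1 with an annealed probability alpha_p and from x2 otherwise, so z = mix(z, cur, 30000) starts decoding mostly from the encoder activations cur and shifts to the latent z over roughly 30k steps; the same hunk also lowers the reconstruction-loss weight from 0.001 to 0.0001. A rough NumPy sketch of the mixing schedule, assuming common_layers.inverse_lin_decay ramps approximately linearly from 0 to 1 over the given number of steps; mix_np and the explicit global_step argument are stand-ins, not the library API:

import numpy as np

def mix_np(x1, x2, global_step, steps, min_prob=0.0, max_prob=1.0):
  # Probability of taking an element from x1: ramps from ~0 at step 0
  # to ~1 at `steps`, then rescaled into [min_prob, max_prob].
  alpha_p = min(global_step / float(steps), 1.0) + 0.001
  alpha_p = alpha_p * (max_prob - min_prob) + min_prob
  # Element-wise Bernoulli mask: 1 -> take x1, 0 -> take x2.
  alpha = (np.random.uniform(size=x1.shape) < alpha_p).astype(x1.dtype)
  return alpha * x1 + (1.0 - alpha) * x2

x1, x2 = np.ones((2, 3)), np.zeros((2, 3))
print(mix_np(x1, x2, global_step=15000, steps=30000))  # roughly half the entries are 1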