This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 3e295e7

Lukasz Kaiser authored and Ryan Sepassi committed
Play with VAE more, bump version.
PiperOrigin-RevId: 165523404
1 parent 8f99d47 commit 3e295e7

File tree: 2 files changed, +57 -14 lines

  setup.py
  tensor2tensor/models/transformer_vae.py

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.1.8',
+    version='1.1.9',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='no-reply@google.com',

tensor2tensor/models/transformer_vae.py

Lines changed: 56 additions & 13 deletions
@@ -26,6 +26,7 @@
 from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_layers
 from tensor2tensor.models import transformer
+from tensor2tensor.utils import expert_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model

@@ -84,12 +85,37 @@ def decompress_step(source, c, hparams, first_relu, name):
   return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size])
 
 
-def dvae(x, hparams, name):
+def top_k_softmax(x, k):
+  """Calculate softmax(x), select top-k and rescale to sum to 1."""
+  x = tf.nn.softmax(x)
+  top_x, _ = tf.nn.top_k(x, k=k+1)
+  min_top = tf.reduce_min(top_x, axis=-1, keep_dims=True)
+  x = tf.nn.relu((x - min_top) + 1e-12)
+  x /= tf.reduce_sum(x, axis=-1, keep_dims=True)
+  return x, tf.reduce_max(top_x, axis=-1)
+
+
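The new top_k_softmax thresholds a softmax at its (k+1)-th largest value, so only the top k probabilities survive before renormalization; the 1e-12 epsilon keeps the k-th entry itself strictly positive. A minimal NumPy sketch of the same computation (illustrative only; top_k_softmax_np is a hypothetical stand-in for the TF version above):

import numpy as np

def top_k_softmax_np(logits, k):
  """NumPy sketch: softmax, zero all but the top k entries, renormalize."""
  x = np.exp(logits - logits.max(axis=-1, keepdims=True))
  x /= x.sum(axis=-1, keepdims=True)                 # softmax
  top = np.sort(x, axis=-1)[..., ::-1][..., :k + 1]  # the k+1 largest values
  min_top = top.min(axis=-1, keepdims=True)          # the (k+1)-th largest
  x = np.maximum((x - min_top) + 1e-12, 0.0)         # relu-style threshold
  x /= x.sum(axis=-1, keepdims=True)                 # rescale to sum to 1
  return x, top.max(axis=-1)

probs, max_p = top_k_softmax_np(np.array([2.0, 1.0, 0.5, -1.0]), k=2)
print(probs)   # only the two largest softmax entries remain non-zero
print(max_p)   # the maximum probability, returned alongside the distribution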
+def top_k_experts(x, k, hparams):
+  x_shape = tf.shape(x)
+  x_flat = tf.reshape(x, [-1, x.get_shape().as_list()[-1]])
+  is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+  gates, load = expert_utils.noisy_top_k_gating(
+      x_flat, hparams.v_size, is_training, k)
+  gates_shape = [x_shape[0], x_shape[1], x_shape[2], hparams.v_size]
+  gates = tf.reshape(gates, gates_shape)
+  load_loss = expert_utils.cv_squared(load)
+  return gates, load_loss
+
+
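top_k_experts reuses the noisy top-k gating from the mixture-of-experts utilities: noisy_top_k_gating returns a sparse gate distribution over hparams.v_size slots together with a per-slot load estimate, and cv_squared(load) is the usual load-balancing penalty (squared coefficient of variation). A hedged usage sketch in TF 1.x, with a hypothetical hparams stand-in and top_k_experts assumed to be in scope:

import tensorflow as tf  # TF 1.x, as used throughout this repo

class FakeHParams(object):                 # hypothetical stand-in for t2t hparams
  v_size = 64                              # number of gate slots ("experts")
  mode = tf.contrib.learn.ModeKeys.TRAIN   # turns the gating noise on

x = tf.zeros([8, 16, 1, 512])              # [batch, length, 1, hidden_size]
gates, load_loss = top_k_experts(x, 8, FakeHParams())
# gates: [8, 16, 1, 64], at most 8 non-zero entries per position, summing to 1
# load_loss: scalar penalty that discourages uneven slot usage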
+def dvae(x, k, hparams, name):
   with tf.variable_scope(name):
     m = tf.layers.dense(x, hparams.v_size, name="mask")
-    m = tf.nn.softmax(m)
-    kl = - tf.reduce_max(m, axis=-1)
-    return m, tf.reduce_mean(kl)
+    if k is None:
+      m = tf.nn.softmax(m)
+      kl = - tf.reduce_max(m, axis=-1)
+    else:
+      m, kl = top_k_softmax(m, k)
+    return m, 1.0 - tf.reduce_mean(kl)
 
 
 def vae(x, hparams, name):
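dvae now takes a k argument: with k=None it keeps the old dense softmax path, otherwise it routes through top_k_softmax; either way the scalar it returns is now 1.0 - mean(kl). Note the branch asymmetry: in the top-k branch kl is the maximum probability itself, so a peaked selection drives the loss toward 0, while in the k=None branch kl is the negated maximum. A small worked example of the top-k branch's arithmetic (NumPy, illustrative only):

import numpy as np

# In the top-k branch, kl is the max probability of the thresholded softmax,
# so the dvae loss 1.0 - mean(kl) rewards peaked (near one-hot) selections.
peaked = np.array([[0.97, 0.01, 0.01, 0.01]])  # near one-hot over v_size=4
flat = np.full((1, 4), 0.25)                   # maximally uncertain
for m in (peaked, flat):
  print(1.0 - m.max(axis=-1).mean())           # 0.03 for peaked, 0.75 for flat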
@@ -119,42 +145,59 @@ def compress(x, c, hparams, name):
   return cur
 
 
+def mix(x1, x2, steps, min_prob=0.0, max_prob=1.0, mode="lin"):
+  if mode == "lin":
+    alpha_p = common_layers.inverse_lin_decay(steps) + 0.001
+  else:
+    alpha_p = common_layers.inverse_exp_decay(steps) + 0.001
+  alpha_p = alpha_p * (max_prob - min_prob) + min_prob
+  alpha = tf.random_uniform(tf.shape(x1))
+  alpha = tf.to_float(tf.less(alpha, alpha_p))
+  return alpha * x1 + (1.0 - alpha) * x2
+
+
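mix generalizes the hand-rolled "leak" that vae_compress previously inlined: element-wise it keeps x1 with probability alpha_p and falls back to x2 otherwise, where alpha_p ramps from near 0 to max_prob over the given number of global steps. A pure-Python sketch of the schedule, assuming common_layers.inverse_lin_decay(steps) returns min(global_step / steps, 1.0) (an assumption about that helper, not verified here):

import random

def mix_prob(step, steps, min_prob=0.0, max_prob=1.0):
  # "lin" mode; assumes inverse_lin_decay(steps) == min(step / steps, 1.0)
  alpha_p = min(float(step) / steps, 1.0) + 0.001
  return alpha_p * (max_prob - min_prob) + min_prob

def mix_scalar(x1, x2, step, steps):
  # Per element, mix() keeps x1 with probability alpha_p, else x2.
  return x1 if random.random() < mix_prob(step, steps) else x2

for step in (0, 15000, 30000):
  print(step, mix_prob(step, 30000))     # 0.001, 0.501, 1.001
print(mix_scalar(1.0, 0.0, 15000, 30000))  # 1.0 roughly half the time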
 def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None):
   """Compress, then VAE."""
+  mix_k = 8
   with tf.variable_scope(compress_name, reuse=reuse):
     cur = compress(x, None, hparams, "compress")
     # Convolve and ReLu to get state.
     cur = common_layers.conv_block(
         cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv")
     # z, kl_loss, mu, log_sigma = vae(cur, hparams, name="vae")
-    z, kl_loss = dvae(cur, hparams, name="dvae")
+    z, kl_loss = dvae(cur, None, hparams, name="dvae")
+    z1, kl_loss1 = top_k_experts(cur, mix_k, hparams)
     mu, log_sigma = None, None
 
+  # Mix expert-selection and flat selection.
+  alpha_p = common_layers.inverse_lin_decay(60000) + 0.001
+  z = alpha_p * z1 + (1 - alpha_p) * z
+  kl_loss += kl_loss1
+
   # Compress context.
   with tf.variable_scope(compress_name, reuse=reuse):
     compress_c = compress(c, None, hparams, "compress_context")
     c_z = tf.layers.dense(compress_c, hparams.v_size, name="mask_context")
     reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits(
         labels=z, logits=c_z)
 
+  # If not training, use the predicted z instead of the autoregressive one.
+  # if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
+  #   z = mix(c_z, z, 50000, max_prob=0.3, mode="exp")
+  #   z, _ = top_k_softmax(c_z, mix_k)
+
   with tf.variable_scope(decompress_name, reuse=reuse):
     # Decompress.
     z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense")
 
     # Leak at the beginning to help train.
-    alpha_p = common_layers.inverse_lin_decay(30000) + 0.001
-    alpha = tf.random_uniform(tf.shape(cur))
-    alpha = tf.to_float(tf.less(alpha, alpha_p))
-    z = alpha * z + (1.0 - alpha) * cur
-
-    # TODO(lukaszkaiser): If not training, use the predicted z.
-    # is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
+    z = mix(z, cur, 30000)
 
     for i in xrange(hparams.num_compress_steps):
       j = hparams.num_compress_steps - i - 1
       z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j)
       z = decompress_step(z, c, hparams, i > 0, "decompress_step_%d" % j)
-    return z, kl_loss + 0.001 * reconstruct_loss, mu, log_sigma
+    return z, kl_loss + 0.0001 * reconstruct_loss, mu, log_sigma
 
 
 def encode(x, x_space, hparams, name):
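Taken together, vae_compress now blends two discrete selections: the flat dvae softmax dominates early, and the top-k expert gating (mix_k = 8) takes over as inverse_lin_decay(60000) approaches 1; the old inline leak becomes mix(z, cur, 30000); and the context-reconstruction term is down-weighted from 0.001 to 0.0001. A pure-Python condensation of the new wiring (hypothetical helper names, not runnable against the repo; schedule assumption as in the mix sketch above):

def blend_weight(step, steps=60000):
  # assumes inverse_lin_decay(steps) == min(step / steps, 1.0)
  return min(float(step) / steps, 1.0) + 0.001

def latent_and_loss(z_flat, z_experts, kl_flat, kl_experts,
                    reconstruct_loss, step):
  """Condensed view of the new vae_compress latent blend and loss."""
  a = blend_weight(step)
  z = [a * e + (1 - a) * f for e, f in zip(z_experts, z_flat)]
  loss = kl_flat + kl_experts + 0.0001 * reconstruct_loss
  return z, loss

z, loss = latent_and_loss(z_flat=[0.9, 0.1], z_experts=[0.5, 0.5],
                          kl_flat=0.2, kl_experts=0.05,
                          reconstruct_loss=3.0, step=15000)
print(z, loss)  # the blend moves toward the expert gates as step grows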
