
Commit 151dc27

"Adding mixture transformer"

T2T Team authored and copybara-github committed

PiperOrigin-RevId: 240229309

1 parent 150aad3 · commit 151dc27

15 files changed: +267 −894 lines

tensor2tensor/layers/common_layers.py

Lines changed: 92 additions & 0 deletions
@@ -1790,6 +1790,98 @@ def padded_cross_entropy(logits,
   return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
 
 
+def padded_cross_entropy_mixture(logits,
+                                 labels,
+                                 label_smoothing,
+                                 num_mixtures,
+                                 weights_fn=weights_nonzero,
+                                 reduce_sum=False,
+                                 cutoff=0.0,
+                                 gaussian=False,
+                                 return_best_logits=False):
+  """Compute cross-entropy assuming 0s are padding.
+
+  Computes a loss numerator (the sum of losses) and a loss denominator
+  (the number of non-padding tokens).
+
+  Computes cross-entropy for each mixture and returns the corresponding
+  values for the mixture with the highest probability.
+
+  Args:
+    logits: `Tensor` with shape `[batch * num_mixtures, timesteps, vocab_size]`,
+      optionally a FactoredTensor.
+    labels: an integer `Tensor` with shape `[batch, timesteps]`.
+    label_smoothing: a floating point `Scalar`.
+    num_mixtures: an integer.
+    weights_fn: A function from labels to weights.
+    reduce_sum: a Boolean, whether to sum at the end or not.
+    cutoff: a float, at which point to have no loss.
+    gaussian: If true, use a Gaussian distribution for label smoothing.
+    return_best_logits: If true, return the logits of the mixture with the
+      highest probability for each example.
+
+  Returns:
+    loss_numerator: a `Scalar`. Sum of losses.
+    loss_denominator: a `Scalar`. The number of non-padding target tokens.
+
+  Raises:
+    ValueError: in case of unsupported argument types.
+  """
+  logit_shapes = shape_list(
+      logits)  # batch_size * num_mixtures, timesteps, 1, 1, vocab_size
+  batch_size = tf.cast(logit_shapes[0] / num_mixtures, dtype=tf.int32)
+  timesteps = logit_shapes[1]
+  vocab_size = logit_shapes[4]
+
+  new_shape_for_xent = [num_mixtures] + shape_list(labels)
+  labels = tf.tile(labels, [num_mixtures, 1, 1, 1])
+
+  xent, weights = padded_cross_entropy(
+      logits, labels, label_smoothing, weights_fn, reduce_sum, cutoff, gaussian)
+
+  # Reshape xent and weights to have num_mixtures as the first dimension.
+  xent = tf.reshape(xent, new_shape_for_xent)
+  weights = tf.reshape(weights, new_shape_for_xent[:-1])
+
+  # Sum up each sentence's negative log-probabilities.
+  xent = tf.reduce_sum(xent, axis=2)
+
+  # If requested, gather the logits of the best mixture for each example.
+  if return_best_logits:
+    best_mixture_indices = tf.cast(tf.argmin(xent, 0), dtype=tf.int32)
+    individual_element_indices = tf.range(batch_size)
+    stacked_mixture_element_indices = tf.stack(
+        (tf.squeeze(best_mixture_indices), individual_element_indices), -1)
+    best_logits = tf.reshape(logits,
+                             [num_mixtures, -1, timesteps, 1, 1, vocab_size])
+    best_logits = tf.gather_nd(best_logits, stacked_mixture_element_indices)
+    best_logits = tf.reshape(best_logits,
+                             [batch_size, timesteps, 1, 1, vocab_size])
+
+  with tf.control_dependencies([
+      tf.assert_equal(
+          tf.shape(xent)[:3], [num_mixtures, batch_size, 1],
+          message="Each batch element should have a probability value for "
+          "each mixture element")
+  ]):
+    xent = tf.reduce_min(xent, axis=0)
+    weights = tf.reduce_mean(weights, axis=0)
+
+  with tf.control_dependencies([
+      tf.assert_equal(
+          tf.shape(xent)[0], [batch_size],
+          message="There should be batch_size elements after selecting the "
+          "best mixture probabilities")
+  ]):
+    summed_xent = tf.reduce_sum(xent)
+    summed_weights = tf.reduce_sum(weights)
+
+  if return_best_logits:
+    return summed_xent, summed_weights, best_logits
+  else:
+    return summed_xent, summed_weights
+
+
 def _weights_one_third(labels):
   """Returns Tensor of shape [batch, height, width]. Each element is 1/3."""
   return tf.ones(tf.shape(labels)[:-1]) / 3.
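
For orientation, a minimal usage sketch of the new loss follows. It is not part of the commit; the shapes, random inputs, and TF1-style session are illustrative assumptions. Note that the labels must already carry the 4-D `[batch, timesteps, 1, 1]` layout that `padded_cross_entropy` works with, even though the docstring quotes the logical `[batch, timesteps]` shape.

# Illustrative sketch only (assumes tensor2tensor is installed and a
# TF1-style graph/session environment, matching the repo's era).
import tensorflow as tf
from tensor2tensor.layers import common_layers

batch_size, timesteps, vocab_size, num_mixtures = 2, 5, 11, 4

# One set of logits per mixture, stacked along the batch dimension:
# shape [batch * num_mixtures, timesteps, 1, 1, vocab_size].
logits = tf.random_normal(
    [batch_size * num_mixtures, timesteps, 1, 1, vocab_size])
# Integer labels in the 4-D layout used by padded_cross_entropy; zeros
# count as padding, and the function tiles the labels across mixtures.
labels = tf.random_uniform(
    [batch_size, timesteps, 1, 1], maxval=vocab_size, dtype=tf.int32)

loss_num, loss_den = common_layers.padded_cross_entropy_mixture(
    logits, labels, label_smoothing=0.1, num_mixtures=num_mixtures)
loss = loss_num / loss_den  # per-token loss of the best mixture

with tf.Session() as sess:
  print(sess.run(loss))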

tensor2tensor/layers/common_video.py

Lines changed: 0 additions & 2 deletions
@@ -790,8 +790,6 @@ def finish(self):
     (out, err) = [
         b"".join(chunks) for chunks in (self._out_chunks, self._err_chunks)
     ]
-    self.proc.stdout.close()
-    self.proc.stderr.close()
     if self.proc.returncode:
       err = "\n".join([" ".join(self.cmd), err.decode("utf8")])
       raise IOError(err)
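
For context, the deleted pair of calls closed the ffmpeg pipes by hand at the end of finish(). A simplified, self-contained sketch of the surrounding pattern (hypothetical standalone variables; the real class keeps the chunk lists and reader threads on self) shows the sequence: the reader threads drain both pipes to EOF and proc.wait() reaps the process before the return code is checked.

# Simplified sketch of the finish() pattern; names here are hypothetical.
import subprocess
import threading

# Stand-in command; the real code runs an ffmpeg encoding pipeline.
proc = subprocess.Popen(["ffmpeg", "-version"],
                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)

out_chunks, err_chunks = [], []

def drain(pipe, chunks):
  # Read the pipe to EOF so the child never blocks on a full buffer.
  for chunk in iter(lambda: pipe.read(4096), b""):
    chunks.append(chunk)

threads = [
    threading.Thread(target=drain, args=(proc.stdout, out_chunks)),
    threading.Thread(target=drain, args=(proc.stderr, err_chunks)),
]
for t in threads:
  t.start()
proc.wait()
for t in threads:
  t.join()

out, err = [b"".join(chunks) for chunks in (out_chunks, err_chunks)]
if proc.returncode:
  raise IOError(err.decode("utf8"))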

tensor2tensor/models/research/rl.py

Lines changed: 9 additions & 43 deletions
@@ -354,18 +354,12 @@ def dqn_atari_base():
       optimizer_epsilon=0.00001,
       optimizer_centered=True,
 
-      # TODO: change names maybe replay_buffer -> agent? Also batch_size is now
-      # buffer_batch_size in _DQNAgent.
       replay_buffer_replay_capacity=1000000,
-      replay_buffer_buffer_batch_size=32,
+      replay_buffer_batch_size=32,
 
       time_limit=27000,
       save_every_steps=50000,
       num_frames=int(20 * 1e6),
-
-      # TODO(konradczechowski) this is not used in trainer_model_free, clean
-      # this up after evaluation refactor
-      eval_episodes_num=3,
   )
 
 
@@ -376,16 +370,6 @@ def dqn_original_params():
   hparams.set_hparam("num_frames", int(1e6))
   return hparams
 
-def rlmf_tiny_overrides():
-  """Parameters to override for tiny setting excluding agent-related hparams."""
-  return dict(
-      max_num_noops=1,
-      eval_max_num_noops=1,
-      rl_env_max_episode_steps=7,
-      eval_rl_env_max_episode_steps=7,
-      eval_sampling_temps=[0.0, 1.0],
-  )
-
 
 @registry.register_hparams
 def rlmf_original():
@@ -398,7 +382,6 @@ def rlmf_original():
       eval_batch_size=2,
       frame_stack_size=4,
       eval_sampling_temps=[0.0, 0.2, 0.5, 0.8, 1.0, 2.0],
-      max_num_noops=8,
       eval_max_num_noops=8,
       eval_rl_env_max_episode_steps=1000,
       resize_height_factor=2,
@@ -443,31 +426,6 @@ def rlmf_base():
   return hparams
 
 
-@registry.register_hparams
-def rlmf_tiny():
-  """Tiny set of hparams for model-free PPO."""
-  hparams = rlmf_original()
-  hparams = hparams.override_from_dict(rlmf_tiny_overrides())
-  hparams.batch_size = 2
-  hparams.add_hparam("ppo_epochs_num", 3)
-  hparams.add_hparam("ppo_epoch_length", 2)
-  return hparams
-
-
-@registry.register_hparams
-def rlmf_dqn_tiny():
-  hparams = rlmf_original()
-  hparams = hparams.override_from_dict(rlmf_tiny_overrides())
-  hparams.batch_size = 1
-  hparams.base_algo = "dqn"
-  hparams.base_algo_params = "dqn_original_params"
-  hparams.add_hparam("dqn_num_frames", 128)
-  hparams.add_hparam("dqn_save_every_steps", 128)
-  hparams.add_hparam("dqn_replay_buffer_replay_capacity", 100)
-  hparams.add_hparam("dqn_agent_min_replay_history", 10)
-  return hparams
-
-
 @registry.register_hparams
 def rlmf_eval():
   """Eval set of hparams for model-free PPO."""
@@ -484,6 +442,14 @@ def rlmf_eval():
   return hparams
 
 
+@registry.register_hparams
+def rlmf_tiny():
+  hparams = rlmf_base()
+  hparams.ppo_epochs_num = 100
+  hparams.ppo_eval_every_epochs = 10
+  return hparams
+
+
 class PolicyBase(t2t_model.T2TModel):
 
   def loss(self, *args, **kwargs):
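
One note on the dqn_atari_base() change: replay_buffer_buffer_batch_size becomes replay_buffer_batch_size, which matters if the replay_buffer_ prefix is stripped before the values are forwarded to the replay buffer, so the kwarg would arrive as batch_size rather than buffer_batch_size. A toy sketch of that convention follows; the helper name is hypothetical and the actual forwarding code in tensor2tensor may differ.

# Hypothetical helper illustrating the prefix-stripping convention the
# hparam names suggest; not the actual tensor2tensor forwarding code.
def extract_prefixed_kwargs(hparams_dict, prefix):
  """Returns {name_without_prefix: value} for keys starting with prefix."""
  return {
      key[len(prefix):]: value
      for key, value in hparams_dict.items()
      if key.startswith(prefix)
  }

hparams = {
    "replay_buffer_replay_capacity": 1000000,
    "replay_buffer_batch_size": 32,
    "time_limit": 27000,
}
print(extract_prefixed_kwargs(hparams, "replay_buffer_"))
# {'replay_capacity': 1000000, 'batch_size': 32}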

tensor2tensor/rl/batch_dqn_agent_test.py

Lines changed: 0 additions & 157 deletions
This file was deleted.
