
Commit 39652cd

nshazeer authored and kpe committed
Add an hparam use_global_position_in_packed_sequence in mtf_transformer2.
If True (default), then we use the global position in the packed example as the input to the positional embedding. If False, then we use the position in the individual sequence.

It is counterintuitive that True should be the default, since False seems to make more sense. However, the previously submitted CL had the effect of changing from True to False, which caused some models to diverge. This CL restores the previous working state.

TODO(noam): investigate why the models diverge with False.

PiperOrigin-RevId: 233427027
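A minimal sketch of what this means in practice, using the mtf_transformer2_base hparams set touched by this commit (the True default is the behavior restored here):

from tensor2tensor.models import mtf_transformer2

hparams = mtf_transformer2.mtf_transformer2_base()
# Default restored by this CL: the positional embedding is indexed by the
# global position within the packed (concatenated) example.
assert hparams.use_global_position_in_packed_sequence

# Setting it to False indexes by the position within each individual
# sequence instead; this is the variant observed to make some models diverge.
hparams.use_global_position_in_packed_sequence = False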
1 parent e7e89bf commit 39652cd

File tree: 3 files changed (+64, -34 lines)


tensor2tensor/data_generators/wiki_multi_problems.py

Lines changed: 40 additions & 17 deletions
@@ -80,26 +80,30 @@ class LanguagemodelMultiWikiTranslatePacked1k(
   """Wiki-LM, Translation, MNLI, SQUAD mixed problem class."""

   def __init__(self, was_reversed=False, was_copy=False):
-    problems = [
-        # TODO(noam): uncommonet once data is generated
-        wiki_lm.LanguagemodelDeEnFrRoWiki64kFitbPacked1k(),
-        wiki_lm.LanguagemodelDeEnFrRoWiki64kFitbPacked1k(was_reversed=True),
-        translate_ende.TranslateEndeWmtMulti64kPacked1k(),
-        translate_ende.TranslateEndeWmtMulti64kPacked1k(was_reversed=True),
-        translate_enfr.TranslateEnfrWmtMulti64kPacked1k(),
-        translate_enfr.TranslateEnfrWmtMulti64kPacked1k(was_reversed=True),
-        translate_enro.TranslateEnroWmtMultiTiny64kPacked1k(),
-        translate_enro.TranslateEnroWmtMultiTiny64kPacked1k(was_reversed=True),
-        cnn_dailymail.SummarizeCnnDailymailMulti64kPacked1k(),
-        cnn_dailymail.SummarizeCnnDailymailMulti64kPacked1k(was_reversed=True),
-        multinli.MultiNLIText2textMulti64kPacked1k(),
-        squad.SquadText2textMulti64kPacked1k(),
-    ]
-    schedule = multi_problem_v2.constant_schedule(
-        multi_problem_v2.epoch_rates_to_pmf(problems))
+    problems = []
+    rates = []
+    for rate, also_reverse, cls in self.problems_and_rates:
+      for r in [False, True] if also_reverse else [False]:
+        problems.append(cls(was_reversed=r))
+        rates.append(rate)
+    pmf = multi_problem_v2.epoch_rates_to_pmf(problems, epoch_rates=rates)
+    schedule = multi_problem_v2.constant_schedule(pmf)
     super(LanguagemodelMultiWikiTranslatePacked1k, self).__init__(
         problems, schedule, was_reversed=was_reversed, was_copy=was_copy)

+  @property
+  def problems_and_rates(self):
+    """Returns a list of (weight, also_reverse, problem_class) triples."""
+    return [
+        (1.0, True, wiki_lm.LanguagemodelDeEnFrRoWiki64kFitbPacked1k),
+        (1.0, True, translate_ende.TranslateEndeWmtMulti64kPacked1k),
+        (1.0, True, translate_enfr.TranslateEnfrWmtMulti64kPacked1k),
+        (1.0, True, translate_enro.TranslateEnroWmtMultiTiny64kPacked1k),
+        (1.0, True, cnn_dailymail.SummarizeCnnDailymailMulti64kPacked1k),
+        (1.0, False, multinli.MultiNLIText2textMulti64kPacked1k),
+        (1.0, False, squad.SquadText2textMulti64kPacked1k),
+    ]
+
   @property
   def has_inputs(self):
     return True
@@ -117,6 +121,25 @@ def packed_length(self):
     return 1024


+@registry.register_problem
+class LanguagemodelMultiWikiTranslatePacked1kV2(
+    LanguagemodelMultiWikiTranslatePacked1k):
+  """Higher rates for rarer problems."""
+
+  @property
+  def problems_and_rates(self):
+    """Returns a list of (weight, also_reverse, problem_class) triples."""
+    return [
+        (1.0, True, wiki_lm.LanguagemodelDeEnFrRoWiki64kFitbPacked1k),
+        (3.0, True, translate_ende.TranslateEndeWmtMulti64kPacked1k),
+        (1.0, True, translate_enfr.TranslateEnfrWmtMulti64kPacked1k),
+        (100.0, True, translate_enro.TranslateEnroWmtMultiTiny64kPacked1k),
+        (1.0, True, cnn_dailymail.SummarizeCnnDailymailMulti64kPacked1k),
+        (10.0, False, multinli.MultiNLIText2textMulti64kPacked1k),
+        (10.0, False, squad.SquadText2textMulti64kPacked1k),
+    ]
+
+
 @registry.register_problem
 class LanguagemodelEnWikiLMMultiNLISubwords64k(multi_problem.MultiProblem):
   """Wiki LM and MNLI mixed problem class."""

tensor2tensor/models/mtf_transformer2.py

Lines changed: 22 additions & 16 deletions
@@ -136,13 +136,16 @@ def import_feature(key):
       return self._import_feature(features, mesh, key)
     targets = import_feature("targets")
     sequence_id = import_feature("targets_segmentation")
-    position = import_feature("targets_position")
+    if hparams.use_global_position_in_packed_sequence:
+      position = None
+    else:
+      position = import_feature("targets_position")
     if self.autoregressive:
       inputs = mtf.shift(
           targets, offset=1, dim=self.length_dim, wrap=False)
-      if position is not None:
-        # first input in later sequences should be 0
-        inputs *= mtf.to_int32(mtf.not_equal(position, 0))
+      # We should have a 0 at the beginning of each sequence rather than the
+      # shifted EOS (1) from the previous sequence.
+      inputs -= mtf.to_int32(mtf.equal(inputs, 1))
     else:
       inputs = import_feature("inputs")
       # TODO(noam): options for bert-style masking here?
@@ -248,8 +251,12 @@ def import_feature(key):
     decoder_sequence_id = import_feature("targets_segmentation")
     if decoder_sequence_id is None:
       decoder_sequence_id = mtf.to_int32(mtf.not_equal(targets, 0))
-    encoder_position = import_feature("inputs_position")
-    decoder_position = import_feature("targets_position")
+    if hparams.use_global_position_in_packed_sequence:
+      encoder_position = None
+      decoder_position = None
+    else:
+      encoder_position = import_feature("inputs_position")
+      decoder_position = import_feature("targets_position")
     model = self.model()
     logits, loss = model.call_simple(
         inputs=inputs,
@@ -349,7 +356,7 @@ def layer_stack_from_hparams(hparams, prefix):
   """Create a layer stack based on the hyperparameter values."""
   layers = hparams.get(prefix + "layers")
   return transformer.LayerStack(
-      [layers_registry.get(l)(hparams, prefix) for l in layers],
+      [layers_registry[l](hparams, prefix) for l in layers],
       dropout_rate=hparams.layer_prepostprocess_dropout,
       norm_epsilon=hparams.norm_epsilon)

@@ -418,6 +425,14 @@ def mtf_transformer2_base():
       "targets": modalities.ModalityType.IDENTITY_SYMBOL,
   }
   hparams.add_hparam("beam_size", 1)
+
+  # If this is True, then in a packed dataset (where examples are concatenated
+  # to form longer examples) we use the global position (within the concatenated
+  # sequence) to compute the positional embedding, instead of the position
+  # within the individual sequence. This is counterintuitive, but for some
+  # reason, it keeps the model from diverging.
+  hparams.add_hparam("use_global_position_in_packed_sequence", True)
+
   return hparams


@@ -837,12 +852,3 @@ def mtr_tr_ende_deep():
   hparams.encoder_num_layers = 12
   hparams.decoder_num_layers = 12
   return hparams
-
-
-@registry.register_hparams
-def ogm_dense_0():
-  hparams = mtr_tr_dense(0)
-  hparams.max_length = 1024
-  hparams.batch_size = 128
-  hparams.shared_embedding_and_softmax_weights = True
-  return hparams
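The packed-sequence input change in the first hunk can be illustrated with plain NumPy (a sketch with made-up token ids, assuming EOS id 1 and padding id 0 as in the changed comment):

import numpy as np

# Two packed sequences followed by padding; EOS token id is 1, padding is 0.
targets = np.array([5, 6, 7, 1, 8, 9, 1, 0])

# mtf.shift(targets, offset=1, wrap=False): shift right by one, filling with 0.
inputs = np.concatenate([[0], targets[:-1]])

# inputs -= mtf.to_int32(mtf.equal(inputs, 1)): zero out the shifted EOS so
# each packed sequence starts from 0 rather than the previous sequence's EOS.
inputs -= (inputs == 1).astype(np.int32)

print(inputs)  # [0 5 6 7 0 8 9 0]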

tensor2tensor/utils/trainer_lib.py

Lines changed: 2 additions & 1 deletion
@@ -132,7 +132,8 @@ def create_session_config(log_device_placement=False,
       gpu_options=gpu_options,
       log_device_placement=log_device_placement,
       inter_op_parallelism_threads=inter_op_parallelism_threads,
-      intra_op_parallelism_threads=intra_op_parallelism_threads)
+      intra_op_parallelism_threads=intra_op_parallelism_threads,
+      isolate_session_state=True)
   return config

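For reference, a minimal sketch of the resulting session config, assuming the TF1-style tf.ConfigProto that create_session_config builds; isolate_session_state keeps variables and other per-session resources from being shared between sessions connected to the same target:

import tensorflow as tf

# Mirrors the fields touched in create_session_config; thread counts of 0 let
# TensorFlow pick its own defaults.
config = tf.ConfigProto(
    log_device_placement=False,
    inter_op_parallelism_threads=0,
    intra_op_parallelism_threads=0,
    isolate_session_state=True)

sess = tf.Session(config=config)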
