
Commit 438e29a

Add language_model flag to train a language model by ignoring the encoder.

PiperOrigin-RevId: 182426171
1 parent: 9be88ea

File tree: 4 files changed, +35 −4 lines

  nmt/model.py
  nmt/nmt.py
  nmt/utils/common_test_utils.py
  nmt/utils/standard_hparams_utils.py

nmt/model.py
Lines changed: 13 additions & 3 deletions

@@ -179,6 +179,7 @@ def _set_params_initializer(self,
     self.tgt_vocab_size = hparams.tgt_vocab_size
     self.num_gpus = hparams.num_gpus
     self.time_major = hparams.time_major
+    self.dtype = tf.float32

     # extra_args: to make it flexible for adding external customizable code
     self.single_cell_fn = None

@@ -347,11 +348,14 @@ def build_graph(self, hparams, scope=None):
         bahdanau | normed_bahdanau).
     """
     utils.print_out("# creating %s graph ..." % self.mode)
-    dtype = tf.float32

-    with tf.variable_scope(scope or "dynamic_seq2seq", dtype=dtype):
+    with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype):
       # Encoder
-      self.encoder_outputs, encoder_state = self._build_encoder(hparams)
+      if hparams.language_model:  # no encoder for language modeling
+        self.encoder_outputs = None
+        encoder_state = None
+      else:
+        self.encoder_outputs, encoder_state = self._build_encoder(hparams)

       ## Decoder
       logits, sample_id, final_context_state = self._build_decoder(

@@ -737,6 +741,12 @@ def _build_decoder_cell(self, hparams, encoder_outputs, encoder_state,
         base_gpu=base_gpu
     )

+    if hparams.language_model:
+      encoder_state = cell.zero_state(self.batch_size, self.dtype)
+    elif not hparams.pass_hidden_state:
+      raise ValueError("For non-attentional model, "
+                       "pass_hidden_state needs to be set to True")
+
     # For beam search, we need to replicate encoder infos beam_width times
     if self.mode == tf.contrib.learn.ModeKeys.INFER and hparams.beam_width > 0:
       decoder_initial_state = tf.contrib.seq2seq.tile_batch(
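The decoder-side hunk is the heart of the change: in language-model mode there is no encoder, so `_build_decoder_cell` initializes the decoder from the cell's all-zeros state instead of the encoder's final state. A minimal sketch of that initialization, assuming TensorFlow 1.x (which this repo targets) and illustrative layer sizes:

```python
import tensorflow as tf

# Illustrative values; the real ones come from hparams and the input batch.
batch_size = 32
num_units, num_layers = 512, 2

cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(num_layers)])

# What the new language_model branch does: with no encoder output to pass
# along, the decoder's initial "encoder_state" is simply the zero state.
encoder_state = cell.zero_state(batch_size, tf.float32)
```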

nmt/nmt.py
Lines changed: 18 additions & 1 deletion

@@ -240,6 +240,9 @@ def add_arguments(parser):
                       Average the last N checkpoints for external evaluation.
                       N can be controlled by setting --num_keep_ckpts.\
                       """))
+  parser.add_argument("--language_model", type="bool", nargs="?",
+                      const=True, default=False,
+                      help="True to train a language model, ignoring encoder")

   # Inference
   parser.add_argument("--ckpt", type=str, default="",

@@ -369,6 +372,7 @@ def create_hparams(flags):
      override_loaded_hparams=flags.override_loaded_hparams,
      num_keep_ckpts=flags.num_keep_ckpts,
      avg_ckpts=flags.avg_ckpts,
+      language_model=flags.language_model,
      num_intra_threads=flags.num_intra_threads,
      num_inter_threads=flags.num_inter_threads,
  )

@@ -429,6 +433,16 @@ def extend_hparams(hparams):
   _add_argument(hparams, "num_decoder_residual_layers",
                 num_decoder_residual_layers)

+  # Language modeling
+  if hparams.language_model:
+    hparams.attention = ""
+    hparams.attention_architecture = ""
+    hparams.pass_hidden_state = False
+    hparams.share_vocab = True
+    hparams.src = hparams.tgt
+    utils.print_out("For language modeling, we turn off attention and "
+                    "pass_hidden_state; turn on share_vocab; set src to tgt.")
+
   ## Vocab
   # Get vocab file names first
   if hparams.vocab_prefix:

@@ -464,10 +478,13 @@ def extend_hparams(hparams):
   _add_argument(hparams, "src_vocab_file", src_vocab_file)
   _add_argument(hparams, "tgt_vocab_file", tgt_vocab_file)

-  # Pretrained Embeddings:
+  # Pretrained Embeddings
   _add_argument(hparams, "src_embed_file", "")
   _add_argument(hparams, "tgt_embed_file", "")
   if hparams.embed_prefix:
+    hparams.num_embeddings_partitions = 1
+    utils.print_out(
+        "For pretrained embeddings, set num_embeddings_partitions to 1")
     src_embed_file = hparams.embed_prefix + "." + hparams.src
     tgt_embed_file = hparams.embed_prefix + "." + hparams.tgt

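One detail worth noting about the flag definition: `type="bool"` is not a built-in argparse type. It works because `add_arguments` registers a custom converter near the top of nmt.py. A self-contained sketch of that pattern (the `register` call mirrors the repo's existing style; the parse calls are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# nmt.py registers a "bool" type so flags accept true/false strings.
parser.register("type", "bool", lambda v: v.lower() == "true")

parser.add_argument("--language_model", type="bool", nargs="?",
                    const=True, default=False,
                    help="True to train a language model, ignoring encoder")

print(parser.parse_args([]).language_model)                    # False (default)
print(parser.parse_args(["--language_model"]).language_model)  # True (const)
print(parser.parse_args(["--language_model", "false"]).language_model)  # False
```

With `nargs="?"` and `const=True`, passing the bare flag enables the mode, while an explicit `--language_model false` keeps it off.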

nmt/utils/common_test_utils.py
Lines changed: 1 addition & 0 deletions

@@ -73,6 +73,7 @@ def create_test_hparams(unit_type="lstm",
   # Misc
   standard_hparams.forget_bias = 0.0
   standard_hparams.random_seed = 3
+  language_model=False

   # Vocab
   standard_hparams.src_vocab_size = 5

nmt/utils/standard_hparams_utils.py
Lines changed: 3 additions & 0 deletions

@@ -101,4 +101,7 @@ def create_standard_hparams():
       infer_batch_size=32,
       sampling_temperature=0.0,
       num_translations_per_input=1,
+
+      # Language model
+      language_model=False,
   )
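Taken together, the four files wire the flag from the command line down to graph construction. A sketch of the hparams a language-model run ends up with, mirroring the branch added to `extend_hparams` above (assumes the nmt package is importable; the "en" suffix is illustrative):

```python
from nmt.utils import standard_hparams_utils

hparams = standard_hparams_utils.create_standard_hparams()
hparams.language_model = True  # the new flag; defaults to False above
hparams.tgt = "en"             # monolingual data has only a target side

# Mirrors the language-model branch added to extend_hparams:
if hparams.language_model:
  hparams.attention = ""             # no attention mechanism
  hparams.attention_architecture = ""
  hparams.pass_hidden_state = False  # decoder starts from cell.zero_state
  hparams.share_vocab = True         # one vocabulary for both sides
  hparams.src = hparams.tgt          # src mirrors tgt

assert hparams.src == hparams.tgt == "en"
```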
