Skip to content

Commit

Permalink
Add phrase_table translation argument (OpenNMT#1370)
Browse files (browse the repository at this point in the history)
* Add phrase_table translation argument

If phrase_table is provided (together with replace_unk), the translator looks up the identified source token (the one with the highest attention weight for the generated UNK token) in the phrase table and outputs the corresponding target token. If phrase_table is not provided (or the identified source token does not exist in the table), the source token is copied instead.
  • Loading branch information
ymoslem authored and vince62s committed Mar 28, 2019
1 parent b7a8c21 commit f09cc8c
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
15 changes: 10 additions & 5 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,12 +632,17 @@ def translate_opts(parser):
group.add('--replace_unk', '-replace_unk', action="store_true",
help="Replace the generated UNK tokens with the "
"source token that had highest attention weight. If "
"phrase_table is provided, it will lookup the "
"phrase_table is provided, it will look up the "
"identified source token and give the corresponding "
"target token. If it is not provided(or the identified "
"source token does not exist in the table) then it "
"will copy the source token")

"target token. If it is not provided (or the identified "
"source token does not exist in the table), then it "
"will copy the source token.")
group.add('--phrase_table', '-phrase_table', type=str, default="",
help="If phrase_table is provided (with replace_unk), it will "
"look up the identified source token and give the "
"corresponding target token. If it is not provided "
"(or the identified source token does not exist in "
"the table), then it will copy the source token.")
group = parser.add_argument_group('Logging')
group.add('--verbose', '-verbose', action="store_true",
help='Print scores and predictions for each sentence')
Expand Down
8 changes: 7 additions & 1 deletion onmt/translate/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ class TranslationBuilder(object):
"""

def __init__(self, data, fields, n_best=1, replace_unk=False,
has_tgt=False):
has_tgt=False, phrase_table=""):
self.data = data
self.fields = fields
self._has_text_src = isinstance(
dict(self.fields)["src"], TextMultiField)
self.n_best = n_best
self.replace_unk = replace_unk
self.phrase_table = phrase_table
self.has_tgt = has_tgt

def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
Expand All @@ -48,6 +49,11 @@ def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
if tokens[i] == tgt_field.unk_token:
_, max_index = attn[i].max(0)
tokens[i] = src_raw[max_index.item()]
if self.phrase_table != "":
with open(self.phrase_table, "r") as f:
for line in f:
if line.startswith(src_raw[max_index.item()]):
tokens[i] = line.split('|||')[1].strip()
return tokens

def from_batch(self, translation_batch):
Expand Down
9 changes: 7 additions & 2 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
block_ngram_repeat=0,
ignore_when_blocking=frozenset(),
replace_unk=False,
phrase_table="",
data_type="text",
verbose=False,
report_bleu=False,
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
if self.replace_unk and not self.model.decoder.attentional:
raise ValueError(
"replace_unk requires an attentional decoder.")
self.phrase_table = phrase_table
self.data_type = data_type
self.verbose = verbose
self.report_bleu = report_bleu
Expand Down Expand Up @@ -229,6 +231,7 @@ def from_opt(
block_ngram_repeat=opt.block_ngram_repeat,
ignore_when_blocking=set(opt.ignore_when_blocking),
replace_unk=opt.replace_unk,
phrase_table=opt.phrase_table,
data_type=opt.data_type,
verbose=opt.verbose,
report_bleu=opt.report_bleu,
Expand Down Expand Up @@ -264,7 +267,8 @@ def translate(
tgt=None,
src_dir=None,
batch_size=None,
attn_debug=False):
attn_debug=False,
phrase_table=""):
"""Translate content of ``src`` and get gold scores from ``tgt``.
Args:
Expand Down Expand Up @@ -307,7 +311,8 @@ def translate(
)

xlation_builder = onmt.translate.TranslationBuilder(
data, self.fields, self.n_best, self.replace_unk, tgt
data, self.fields, self.n_best, self.replace_unk, tgt,
self.phrase_table
)

# Statistics
Expand Down

0 comments on commit f09cc8c

Please sign in to comment.