Skip to content

Add phrase_table translation argument #1370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,17 @@ def translate_opts(parser):
group.add('--replace_unk', '-replace_unk', action="store_true",
help="Replace the generated UNK tokens with the "
"source token that had highest attention weight. If "
"phrase_table is provided, it will lookup the "
"phrase_table is provided, it will look up the "
"identified source token and give the corresponding "
"target token. If it is not provided(or the identified "
"source token does not exist in the table) then it "
"will copy the source token")

"target token. If it is not provided (or the identified "
"source token does not exist in the table), then it "
"will copy the source token.")
group.add('--phrase_table', '-phrase_table', type=str, default="",
help="If phrase_table is provided (with replace_unk), it will "
"look up the identified source token and give the "
"corresponding target token. If it is not provided "
"(or the identified source token does not exist in "
"the table), then it will copy the source token.")
group = parser.add_argument_group('Logging')
group.add('--verbose', '-verbose', action="store_true",
help='Print scores and predictions for each sentence')
Expand Down
8 changes: 7 additions & 1 deletion onmt/translate/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ class TranslationBuilder(object):
"""

def __init__(self, data, fields, n_best=1, replace_unk=False,
has_tgt=False):
has_tgt=False, phrase_table=""):
self.data = data
self.fields = fields
self._has_text_src = isinstance(
dict(self.fields)["src"], TextMultiField)
self.n_best = n_best
self.replace_unk = replace_unk
self.phrase_table = phrase_table
self.has_tgt = has_tgt

def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
Expand All @@ -48,6 +49,11 @@ def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
if tokens[i] == tgt_field.unk_token:
_, max_index = attn[i].max(0)
tokens[i] = src_raw[max_index.item()]
if self.phrase_table != "":
with open(self.phrase_table, "r") as f:
for line in f:
if line.startswith(src_raw[max_index.item()]):
tokens[i] = line.split('|||')[1].strip()
return tokens

def from_batch(self, translation_batch):
Expand Down
9 changes: 7 additions & 2 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
block_ngram_repeat=0,
ignore_when_blocking=frozenset(),
replace_unk=False,
phrase_table="",
data_type="text",
verbose=False,
report_bleu=False,
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
if self.replace_unk and not self.model.decoder.attentional:
raise ValueError(
"replace_unk requires an attentional decoder.")
self.phrase_table = phrase_table
self.data_type = data_type
self.verbose = verbose
self.report_bleu = report_bleu
Expand Down Expand Up @@ -229,6 +231,7 @@ def from_opt(
block_ngram_repeat=opt.block_ngram_repeat,
ignore_when_blocking=set(opt.ignore_when_blocking),
replace_unk=opt.replace_unk,
phrase_table=opt.phrase_table,
data_type=opt.data_type,
verbose=opt.verbose,
report_bleu=opt.report_bleu,
Expand Down Expand Up @@ -264,7 +267,8 @@ def translate(
tgt=None,
src_dir=None,
batch_size=None,
attn_debug=False):
attn_debug=False,
phrase_table=""):
"""Translate content of ``src`` and get gold scores from ``tgt``.

Args:
Expand Down Expand Up @@ -307,7 +311,8 @@ def translate(
)

xlation_builder = onmt.translate.TranslationBuilder(
data, self.fields, self.n_best, self.replace_unk, tgt
data, self.fields, self.n_best, self.replace_unk, tgt,
self.phrase_table
)

# Statistics
Expand Down