Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scoring #538

Merged
merged 60 commits into from
Sep 28, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
4fc041f
made validation source optional, parameterized fill_up
mjpost Sep 17, 2018
af1e7bd
moved training-specific checks
mjpost Sep 17, 2018
1322b31
added scoring up to generating outputs, almost finished
mjpost Sep 17, 2018
27dd1cb
works but not polished
mjpost Sep 18, 2018
7e8767a
added length penalty as command-line arguments
mjpost Sep 18, 2018
29ac7a4
added 'repeat_last' fillup strategy (wrong approach)
mjpost Sep 18, 2018
d468770
added zero fill_up strategy, no_permute on batch iterator
mjpost Sep 18, 2018
359bc60
print source sentence with generalized ids2tokens()
mjpost Sep 18, 2018
ebd4e10
fixed test cases
mjpost Sep 18, 2018
8994f5c
Merge branch 'master' into scoring
mjpost Sep 18, 2018
0079c71
style checks
mjpost Sep 18, 2018
210a123
moved training-specific check to training
mjpost Sep 18, 2018
5015d54
context
mjpost Sep 18, 2018
48d00cc
turned on bucketing
mjpost Sep 18, 2018
927b92b
turned off more data permuting
mjpost Sep 18, 2018
581805e
set batch size to reasonable default for fully-unrolled graph
mjpost Sep 19, 2018
aa417dd
pulled out bucketing args, removed options from scoring
mjpost Sep 19, 2018
28a5e4b
added --score-type and --output
mjpost Sep 19, 2018
ddd5d88
Merge branch 'master' into scoring
mjpost Sep 19, 2018
307e56a
documentation
mjpost Sep 19, 2018
98bb1f1
style check
mjpost Sep 19, 2018
c1f4035
versioning
mjpost Sep 19, 2018
6118643
fixed easy issues from @fhieber's CR
mjpost Sep 20, 2018
16864dd
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 20, 2018
6a73883
removed batch preparation
mjpost Sep 20, 2018
bd2a526
moved summing to the computation graph
mjpost Sep 20, 2018
717f18b
now using output_handler
mjpost Sep 20, 2018
fa1d9b8
added warning if sentences get skipped
mjpost Sep 20, 2018
47ddceb
typo
mjpost Sep 20, 2018
6f1c4e8
cleanup and docs
mjpost Sep 20, 2018
ff324fb
documentation and cleanup
mjpost Sep 21, 2018
8805a62
Simplified fill_up, improved comments
mjpost Sep 22, 2018
aef30ae
bugfix with length penalty alpha = 0
mjpost Sep 22, 2018
0ff7feb
bugfixes in names
mjpost Sep 22, 2018
39294ac
combined run_forward() and get_outputs()
mjpost Sep 22, 2018
bbd8c13
uncontroversial reversions
mjpost Sep 24, 2018
9f0a71c
Merge remote-tracking branch 'github/master' into scoring
mjpost Sep 25, 2018
ef6b9f3
stuffing args for scoring, no_permute renamed permute
mjpost Sep 25, 2018
9e1736d
added --output and --softmax-temperature
mjpost Sep 25, 2018
5269a9c
changed float width to 3
mjpost Sep 25, 2018
09c46c2
removed prepare data option, reverted train.py
mjpost Sep 25, 2018
fb601a3
added test cases, bugfix with source factors
mjpost Sep 25, 2018
e049fea
bugfix — get first item from group
mjpost Sep 25, 2018
5f950be
style fixes
mjpost Sep 25, 2018
3c9401d
style changes
mjpost Sep 25, 2018
7c48db9
missed one
mjpost Sep 25, 2018
68f943c
don't score when --skip-topk
mjpost Sep 25, 2018
7db9bef
debugging travis
mjpost Sep 25, 2018
d67488b
seq max seq len very large to pass test
mjpost Sep 25, 2018
d834a7d
reading maxlen from config and skipping some test lines
mjpost Sep 25, 2018
b731159
debugging output since still failing
mjpost Sep 25, 2018
8abcf1b
skipping test outputs with vocab symbols
mjpost Sep 26, 2018
3ae7d7e
debugging travis
mjpost Sep 26, 2018
f39ba62
more systematic testing for when to try to score
mjpost Sep 26, 2018
776fcbb
don't score when translate beam == 1 or length close to max
mjpost Sep 26, 2018
e9808d2
restored skipping topk
mjpost Sep 26, 2018
065abc1
updated documentation
mjpost Sep 26, 2018
b3dd468
entry point for sockeye.score
mjpost Sep 27, 2018
b23bb01
Merge branch 'master' into scoring
fhieber Sep 27, 2018
2da12bc
proper spacing
fhieber Sep 28, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
added test cases, bugfix with source factors
  • Loading branch information
mjpost committed Sep 25, 2018
commit fb601a3644d532638754d0c0d3d72bdf72c11ff8
2 changes: 1 addition & 1 deletion sockeye/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def score(self,
for source, target, score in zip(batch.data[0], batch.data[1], scores):

# The "zeros" padding method will have filled remainder batches with zeros, so we can skip them here
if source[0] == C.PAD_ID:
if source[0][0] == C.PAD_ID:
break

sentence_no += 1
Expand Down
72 changes: 66 additions & 6 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import sockeye.extract_parameters
import sockeye.lexicon
import sockeye.prepare_data
import sockeye.score
import sockeye.train
import sockeye.translate
import sockeye.utils
Expand Down Expand Up @@ -210,6 +211,10 @@ def tmp_digits_dataset(prefix: str,

_TRANSLATE_PARAMS_RESTRICT = "--restrict-lexicon {lexicon} --restrict-lexicon-topk {topk}"

_SCORE_PARAMS_COMMON = "--use-cpu --model {model} --source {source} --target {target} --output {output}"

_SCORE_WITH_FACTORS_COMMON = " --source-factors {source_factors}"

_EVAL_PARAMS_COMMON = "--hypotheses {hypotheses} --references {references} --metrics {metrics} {quiet}"

_EXTRACT_PARAMS = "--input {input} --names target_output_bias --list-all --output {output}"
Expand Down Expand Up @@ -331,22 +336,39 @@ def run_train_translate(train_params: str,
cp_metrics = cp_decoder.decode_and_evaluate()
logger.info("Checkpoint decoder metrics: %s", cp_metrics)

# import shutil
# shutil.copytree(model_path, "/Users/post/code/sockeye/t/model")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove :)


logger.info("Translating with parameters %s.", translate_params)
# Translate corpus with the 1st params
out_path = os.path.join(work_dir, "out.txt")
params = "{} {} {}".format(sockeye.translate.__file__,
_TRANSLATE_PARAMS_COMMON.format(model=model_path,
input=test_source_path,
output=out_path,
quiet=quiet_arg),
translate_params)
translate_score_path = os.path.join(work_dir, "out.scores.txt")
params = "{} {} {} --output-type translation_with_score".format(sockeye.translate.__file__,
_TRANSLATE_PARAMS_COMMON.format(model=model_path,
input=test_source_path,
output=out_path,
quiet=quiet_arg),
translate_params)

if test_source_factor_paths is not None:
params += _TRANSLATE_WITH_FACTORS_COMMON.format(input_factors=" ".join(test_source_factor_paths))

with patch.object(sys, "argv", params.split()):
sockeye.translate.main()

# Break out translation and score
outputs = open(out_path).readlines()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use with statement for this

with open(out_path, 'w') as out_translate, open(translate_score_path, 'w') as out_scores:
for output in outputs:
output = output.strip()
try:
score, translation = output.split('\t')
except:
mjpost marked this conversation as resolved.
Show resolved Hide resolved
score = output
translation = ""
print(translation, file=out_translate)
print(score, file=out_scores)

# Test target constraints
if use_target_constraints:
"""
Expand Down Expand Up @@ -403,6 +425,44 @@ def run_train_translate(train_params: str,
# for negative constraints, ensure the constraints is *not* in the constrained output
assert restriction not in constrained_out


# Test scoring. We make sure that we can score the (input, translation output) and get the same
# model score.
if not use_prepared_data:
## Score
# We use the translation parameters, but have to remove irrelevant arguments from it.
# Currently, the only relevant flag passed is the --softmax-temperature flag.
score_params = ''
if 'softmax-temperature' in translate_params:
tokens = translate_params.split(C.TOKEN_SEPARATOR)
mjpost marked this conversation as resolved.
Show resolved Hide resolved
for i, token in enumerate(tokens):
if token == '--softmax-temperature':
score_params = '--softmax-temperature {}'.format(tokens[i + 1])
break

scores_output_file = out_path + '.score'
params = "{} {} {}".format(sockeye.score.__file__,
_SCORE_PARAMS_COMMON.format(model=model_path,
source=test_source_path,
target=out_path,
output=scores_output_file),
score_params)

if test_source_factor_paths is not None:
params += _SCORE_WITH_FACTORS_COMMON.format(source_factors=" ".join(test_source_factor_paths))

with patch.object(sys, "argv", params.split()):
sockeye.score.main()

## Compare scored output to original translation output. First remove -inf lines from the translate_score_path
## file. These correspond to blank lines in translate, which are skipped in sockeye.score.
for translate_score, score_score in zip(filter(lambda x: x != '-inf\n', open(translate_score_path).readlines()),
mjpost marked this conversation as resolved.
Show resolved Hide resolved
open(scores_output_file).readlines()):
translate_score = float(translate_score)
score_score = float(score_score)
print('SCORES', translate_score, score_score)
mjpost marked this conversation as resolved.
Show resolved Hide resolved
assert abs(translate_score - score_score) < 0.002
mjpost marked this conversation as resolved.
Show resolved Hide resolved

# Translate corpus with the 2nd params
if translate_params_equiv is not None:
out_path_equiv = os.path.join(work_dir, "out_equiv.txt")
Expand Down