Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make inputs directly from a dict #591

Merged
merged 5 commits into from
Nov 30, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.18.62]
### Added
- Add option to make `TranslatorInputs` directly from a dict.

## [1.18.61]
### Changed
- Update to MXNet 1.3.1. Removed requirements/requirements.gpu-cu{75,91}.txt as CUDA 7.5 and 9.1 are deprecated.
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.61'
__version__ = '1.18.62'
34 changes: 27 additions & 7 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,33 +694,53 @@ def make_input_from_json_string(sentence_id: SentenceId, json_string: str) -> Tr
:param sentence_id: Sentence id.
:param json_string: A JSON object serialized as a string that must contain a key "text", mapping to the input text,
and optionally a key "factors" that maps to a list of strings, each of which representing a factor sequence
for the input text.
for the input text. Constraints and an avoid list can also be added through the "constraints" and "avoid"
keys.
:return: A TranslatorInput.
"""
try:
jobj = json.loads(json_string, encoding=C.JSON_ENCODING)
tokens = jobj[C.JSON_TEXT_KEY]
return make_input_from_dict(jobj)

except Exception as e:
logger.exception(e, exc_info=True) if not is_python34() else logger.error(e) # type: ignore
return _bad_input(sentence_id, reason=json_string)


def make_input_from_dict(sentence_id: SentenceId, input_dict: Dict) -> TranslatorInput:
"""
Returns a TranslatorInput object from a JSON object, serialized as a string.

:param sentence_id: Sentence id.
:param input_dict: A dict that must contain a key "text", mapping to the input text, and optionally a key "factors"
that maps to a list of strings, each of which representing a factor sequence for the input text.
Constraints and an avoid list can also be added through the "constraints" and "avoid" keys.
:return: A TranslatorInput.
"""
try:
tokens = input_dict[C.JSON_TEXT_KEY]
tokens = list(data_io.get_tokens(tokens))
factors = jobj.get(C.JSON_FACTORS_KEY)
factors = input_dict.get(C.JSON_FACTORS_KEY)
if isinstance(factors, list):
factors = [list(data_io.get_tokens(factor)) for factor in factors]
lengths = [len(f) for f in factors]
if not all(length == len(tokens) for length in lengths):
logger.error("Factors have different length than input text: %d vs. %s", len(tokens), str(lengths))
return _bad_input(sentence_id, reason=json_string)
return _bad_input(sentence_id, reason=input_dict)

# List of phrases to prevent from occuring in the output
avoid_list = jobj.get(C.JSON_AVOID_KEY)
avoid_list = input_dict.get(C.JSON_AVOID_KEY)

# List of phrases that must appear in the output
constraints = jobj.get(C.JSON_CONSTRAINTS_KEY)
constraints = input_dict.get(C.JSON_CONSTRAINTS_KEY)

# If there is overlap between positive and negative constraints, assume the user wanted
# the words, and so remove them from the avoid_list (negative constraints)
if constraints is not None and avoid_list is not None:
avoid_set = set(avoid_list)
overlap = set(constraints).intersection(avoid_set)
if len(overlap) > 0:
logger.warning("Overlap between constraints and avoid set, dropping the overlapping avoids")
avoid_list = list(avoid_set.difference(overlap))

# Convert to a list of tokens
Expand All @@ -733,7 +753,7 @@ def make_input_from_json_string(sentence_id: SentenceId, json_string: str) -> Tr

except Exception as e:
logger.exception(e, exc_info=True) if not is_python34() else logger.error(e) # type: ignore
return _bad_input(sentence_id, reason=json_string)
return _bad_input(sentence_id, reason=input_dict)


def make_input_from_factored_string(sentence_id: SentenceId,
Expand Down
26 changes: 26 additions & 0 deletions test/unit/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,32 @@ def test_failed_make_input_from_valid_json_string(text, text_key, factors, facto
assert isinstance(inp, sockeye.inference.BadTranslatorInput)


@pytest.mark.parametrize("text, factors",
[("this is a test without factors", None),
("", None),
("test", ["X", "X"]),
("a b c", ["x y z"]),
("a", [])])
def test_make_input_from_valid_dict(text, factors):
sentence_id = 1
expected_tokens = list(sockeye.data_io.get_tokens(text))
inp = sockeye.inference.make_input_from_dict(sentence_id, {C.JSON_TEXT_KEY: text,
C.JSON_FACTORS_KEY: factors})
assert len(inp) == len(expected_tokens)
assert inp.tokens == expected_tokens
if factors is not None:
assert len(inp.factors) == len(factors)
else:
assert inp.factors is None


@pytest.mark.parametrize("text, text_key, factors, factors_key", [("a", "blub", None, "")])
def test_failed_make_input_from_valid_dict(text, text_key, factors, factors_key):
sentence_id = 1
inp = sockeye.inference.make_input_from_dict(sentence_id, {text_key: text, factors_key: factors})
assert isinstance(inp, sockeye.inference.BadTranslatorInput)


@pytest.mark.parametrize("strings",
[
["a b c"],
Expand Down