Skip to content

Commit

Permalink
Added a check for vocabularies loaded from json. (awslabs#596)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdomhan authored and fhieber committed Dec 11, 2018
1 parent 2d61c6e commit 8d37ac3
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
31 changes: 31 additions & 0 deletions sockeye/vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ def vocab_to_json(vocab: Vocab, path: str):
logger.info('Vocabulary saved to "%s"', path)


def is_valid_vocab(vocab: Vocab) -> bool:
"""
Checks if a vocabulary is valid. We define valid as:
1. All indices from 0 to num_words - 1 are present without duplicates.
2. All special symbols C.PAD_SYMBOL, C.UNK_SYMBOL, C.BOS_SYMBOL, C.EOS_SYMBOL are present.
3. PAD_ID has word id 0.
"""
for symbol in [C.PAD_SYMBOL, C.UNK_SYMBOL, C.BOS_SYMBOL, C.EOS_SYMBOL]:
if symbol not in vocab:
logger.warning("%s missing from vocabulary.", symbol)
return False
if vocab[C.PAD_SYMBOL] != 0:
logger.warning("PAD_ID does not have word id 0 in vocabulary.")
return False
word_ids = []
for word, word_id in vocab.items():
word_ids.append(word_id)
word_ids_set = set(word_ids)
if len(word_ids_set) != len(word_ids):
logger.warning("Duplicate word_ids in vocabulary.")
return False

expected_word_ids = set(range(0, len(vocab)))
if expected_word_ids != word_ids_set:
logger.warning("Not all word_ids from 0 to len(vocabulary) present in vocabulary.")
return False

return True


def vocab_from_json(path: str, encoding: str = C.VOCAB_ENCODING) -> Vocab:
"""
Saves vocabulary in json format.
Expand All @@ -122,6 +152,7 @@ def vocab_from_json(path: str, encoding: str = C.VOCAB_ENCODING) -> Vocab:
"""
with open(path, encoding=encoding) as inp:
vocab = json.load(inp)
utils.check_condition(is_valid_vocab(vocab), "Vocabulary %s not valid." % path)
logger.info('Vocabulary (%d words) loaded from "%s"', len(vocab), path)
return vocab

Expand Down
23 changes: 22 additions & 1 deletion test/unit/test_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest

import sockeye.constants as C
from sockeye.vocab import build_vocab, get_ordered_tokens_from_vocab
from sockeye.vocab import build_vocab, get_ordered_tokens_from_vocab, is_valid_vocab

test_vocab = [
# Example 1
Expand Down Expand Up @@ -79,3 +79,24 @@ def test_constants_in_vocab(data, size, min_count, constants):
def test_get_ordered_tokens_from_vocab(vocab, expected_output):
assert get_ordered_tokens_from_vocab(vocab) == expected_output


@pytest.mark.parametrize(
"vocab, expected_result",
[
({symbol: idx for idx, symbol in enumerate(C.VOCAB_SYMBOLS + ["w1", "w2"])}, True),
# A vocabulary with just the valid symbols doesn't make much sense but is technically valid
({symbol: idx for idx, symbol in enumerate(C.VOCAB_SYMBOLS)}, True),
# Manually specifying the list of required special symbol so that we avoid making a backwards-incompatible
# change by adding a new symbol to C.VOCAB_SYMBOLS
({symbol: idx for idx, symbol in enumerate([C.PAD_SYMBOL, C.UNK_SYMBOL, C.BOS_SYMBOL, C.EOS_SYMBOL])}, True),
# PAD_ID must have word id 0
({symbol: idx for idx, symbol in enumerate(reversed(C.VOCAB_SYMBOLS))}, False),
({symbol: idx for idx, symbol in enumerate(list(reversed(C.VOCAB_SYMBOLS)) + ["w1", "w2"])}, False),
# If there is a gap the vocabulary is not valid:
({symbol: idx if symbol != "w2" else idx + 1 for idx, symbol in enumerate(C.VOCAB_SYMBOLS + ["w1", "w2"])}, False),
# There shouldn't be any duplicate word ids
({symbol: idx if symbol != "w2" else idx - 1 for idx, symbol in enumerate(C.VOCAB_SYMBOLS + ["w1", "w2"])}, False),
]
)
def test_verify_valid_vocab(vocab, expected_result):
assert is_valid_vocab(vocab) == expected_result

0 comments on commit 8d37ac3

Please sign in to comment.