
Commit

Improvements to glove2dict, get_vocab, create_pretrained_embedding
cgpotts committed Apr 1, 2020
1 parent 7a692a7 commit f09d481
Showing 2 changed files with 59 additions and 18 deletions.
39 changes: 35 additions & 4 deletions test/test_utils.py
@@ -59,28 +59,59 @@ def test_glove2dict():
     data = utils.glove2dict(src_filename)
     assert len(data) == 400000
 
-@pytest.mark.parametrize("X, n_words, expected", [
+@pytest.mark.parametrize("X, n_words, mincount, expected", [
     [
         [["a", "b", "c"], ["b", "c", "d"]],
         None,
+        1,
         ["$UNK", "a", "b", "c", "d"]
     ],
     [
         [["a", "b", "c"], ["b", "c", "d"]],
         2,
+        1,
         ["$UNK", "b", "c"]
     ],
     [
         [],
         2,
+        1,
         ["$UNK"]
-    ]
+    ],
+    [
+        [["a", "b", "b"], ["b", "c", "a"]],
+        None,
+        3,
+        ["$UNK", "b"]
+    ],
+    [
+        [["b", "b", "b"], ["b", "a", "a", "c"]],
+        2,
+        3,
+        ["$UNK", "b"]
+    ],
 ])
-def test_get_vocab(X, n_words, expected):
-    result = utils.get_vocab(X, n_words=n_words)
+def test_get_vocab(X, n_words, mincount, expected):
+    result = utils.get_vocab(X, n_words=n_words, mincount=mincount)
     assert result == expected
 
 
+@pytest.mark.parametrize("lookup, vocab, required_tokens, expected_shape", [
+    [
+        {"a": [1,2]}, ["a", "b"], ["$UNK"], (3,2)
+    ],
+    [
+        {"a": [1,2], "b": [3,4]}, ["b"], ["$UNK"], (2,2)
+    ]
+])
+def test_create_pretrained_embedding(lookup, vocab, required_tokens, expected_shape):
+    result, new_vocab = utils.create_pretrained_embedding(lookup, vocab, required_tokens)
+    assert result.shape == expected_shape
+    assert "$UNK" in new_vocab
+    new_vocab.remove("$UNK")
+    assert vocab == new_vocab
+
+
 @pytest.mark.parametrize("set_value", [True, False])
 def test_fix_random_seeds_system(set_value):
     params = dict(
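The new and updated cases can be selected with pytest's keyword filtering. A minimal sketch, assuming pytest is installed and this is run from the repository root:

    # Hypothetical invocation; runs only the tests touched by this commit.
    import pytest
    pytest.main(["test/test_utils.py", "-k", "get_vocab or create_pretrained_embedding"])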
38 changes: 24 additions & 14 deletions utils.py
@@ -21,21 +21,30 @@
 
 def glove2dict(src_filename):
     """GloVe Reader.
     Parameters
     ----------
     src_filename : str
         Full path to the GloVe file to be processed.
     Returns
     -------
     dict
-        Mapping words to their GloVe vectors.
+        Mapping words to their GloVe vectors as `np.array`.
     """
+    # This distribution has some words with spaces, so we have to
+    # assume its dimensionality and parse out the lines specially:
+    if '840B.300d' in src_filename:
+        line_parser = lambda line: line.rsplit(" ", 300)
+    else:
+        line_parser = lambda line: line.strip().split()
     data = {}
     with open(src_filename, encoding='utf8') as f:
         while True:
             try:
                 line = next(f)
-                line = line.strip().split()
+                line = line_parser(line)
                 data[line[0]] = np.array(line[1: ], dtype=np.float)
             except StopIteration:
                 break
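The special-cased parser is needed because, as the new comment notes, the 840B.300d distribution contains tokens with internal spaces, which a plain split() misparses. A sketch of the difference, using an assumed space-containing token and 3-dimensional vectors for brevity (the real file is 300-dimensional):

    # A token with internal spaces breaks a plain split():
    line = ". . . 0.1 0.2 0.3"
    line.strip().split()   # ['.', '.', '.', '0.1', '0.2', '0.3'] -- token lost
    # Splitting from the right a fixed number of times keeps it intact:
    line.rsplit(" ", 3)    # ['. . .', '0.1', '0.2', '0.3']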
@@ -173,7 +182,7 @@ def fit_classifier_with_crossvalidation(
     return crossvalidator.best_estimator_
 
 
-def get_vocab(X, n_words=None):
+def get_vocab(X, n_words=None, mincount=1):
     """Get the vocabulary for an RNN example matrix `X`,
     adding $UNK$ if it isn't already present.
@@ -182,6 +191,8 @@ def get_vocab(X, n_words=None):
     X : list of lists of str
     n_words : int or None
         If this is `int > 0`, keep only the top `n_words` by frequency.
+    mincount : int
+        Only words with at least this many tokens are kept.
     Returns
     -------
@@ -190,14 +201,18 @@ def get_vocab(X, n_words=None):
     """
     wc = Counter([w for ex in X for w in ex])
     wc = wc.most_common(n_words) if n_words else wc.items()
-    vocab = {w for w, c in wc}
+    if mincount > 1:
+        wc = {(w, c) for w, c in wc if c >= mincount}
+    vocab = {w for w, _ in wc}
     vocab.add("$UNK")
     return sorted(vocab)


 def create_pretrained_embedding(
         lookup, vocab, required_tokens=('$UNK', "<s>", "</s>")):
     """Create an embedding matrix from a lookup and a specified vocab.
+    Words from `vocab` that are not in `lookup` are given random
+    representations.
     Parameters
     ----------
@@ -212,21 +227,16 @@ def create_pretrained_embedding(
     Returns
     -------
     np.array, list
-        The np.array is an embedding for `vocab`, restricted to words
-        that are in in `lookup`, and sorted alphabetically. The last
-        vector is for $UNK if it is not already in both `lookup`
-        and `vocab`. The list is the updated vocabulary. The words are
-        sorted alphabetically, to align with the embedding, and $UNK is
-        appended the end if it was not already in in both `lookup` and
-        `vocab`.
+        The np.array is an embedding for `vocab` and the `list` is
+        the potentially expanded version of `vocab` that came in.
     """
-    vocab = sorted(set(lookup) & set(vocab))
-    embedding = np.array([lookup[w] for w in vocab])
+    dim = len(next(iter(lookup.values())))
+    embedding = np.array([lookup.get(w, randvec(dim)) for w in vocab])
     for tok in required_tokens:
         if tok not in vocab:
             vocab.append(tok)
-            embedding = np.vstack((embedding, randvec(dim)))
+            embedding = np.vstack((embedding, randvec(dim)))
     return embedding, vocab
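The upshot of this change: words in `vocab` that are missing from `lookup` are no longer silently dropped; they now receive random vectors of the right dimensionality, and the input ordering of `vocab` is preserved. A sketch mirroring the first new test case, assuming `create_pretrained_embedding` is imported from utils:

    from utils import create_pretrained_embedding

    lookup = {"a": [1, 2]}
    embedding, vocab = create_pretrained_embedding(
        lookup, ["a", "b"], required_tokens=("$UNK",))
    embedding.shape   # (3, 2): "a" from lookup; "b" and "$UNK" get random vectors
    vocab             # ['a', 'b', '$UNK'] -- order preserved, $UNK appended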

