Add unit tests for embeddings module. (OpenNMT#1241)
flauted authored and vince62s committed Jan 31, 2019
1 parent e12e8d0 commit 801a3d7
Showing 2 changed files with 190 additions and 4 deletions.
41 changes: 37 additions & 4 deletions onmt/modules/embeddings.py
@@ -1,5 +1,6 @@
""" Embeddings module """
import math
import warnings

import torch
import torch.nn as nn
@@ -21,6 +22,9 @@ class PositionalEncoding(nn.Module):
"""

def __init__(self, dropout, dim, max_len=5000):
if dim % 2 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd dim (got dim={:d})".format(dim))
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
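
For context (not part of the diff): the new check requires an even dim because the sinusoidal encoding pairs each sin channel with a cos channel at the same frequency, so the channels must split evenly into pairs. A minimal standalone sketch of that computation, following the standard Transformer formulation that the div_term line above matches:

import math
import torch

max_len, dim = 10, 8  # toy sizes; dim must be even so sin/cos channels pair up
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) *
                     -(math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)  # even channels
pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
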
@@ -80,7 +84,7 @@ class Embeddings(nn.Module):
concat, sum or mlp.
feat_vec_exponent (float): when using `-feat_merge concat`, feature
embedding size is N^feat_dim_exponent, where N is the
number of values of feature takes.
number of values the feature takes.
feat_vec_size (int): embedding dimension for features when using
`-feat_merge mlp`
dropout (float): dropout probability.
@@ -91,12 +95,15 @@ def __init__(self, word_vec_size,
word_padding_idx,
position_encoding=False,
feat_merge="concat",
feat_vec_exponent=0.7, feat_vec_size=-1,
feat_vec_exponent=0.7,
feat_vec_size=-1,
feat_padding_idx=[],
feat_vocab_sizes=[],
dropout=0,
sparse=False,
fix_word_vecs=False):
self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
feat_vec_size, feat_padding_idx)

if feat_padding_idx is None:
feat_padding_idx = []
@@ -147,8 +154,7 @@ def __init__(self, word_vec_size,

if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
in_dim = sum(emb_dims)
out_dim = word_vec_size
mlp = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU())
mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
self.make_embedding.add_module('mlp', mlp)

self.position_encoding = position_encoding
@@ -160,6 +166,33 @@ def __init__(self, word_vec_size,
if fix_word_vecs:
self.word_lut.weight.requires_grad = False

def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
feat_vec_size, feat_padding_idx):
if feat_merge == "sum":
# features must use word_vec_size
if feat_vec_exponent != 0.7:
warnings.warn("Merging with sum, but got non-default "
"feat_vec_exponent. It will be unused.")
if feat_vec_size != -1:
warnings.warn("Merging with sum, but got non-default "
"feat_vec_size. It will be unused.")
elif feat_vec_size > 0:
# features will use feat_vec_size
if feat_vec_exponent != -1:
warnings.warn("Not merging with sum and positive "
"feat_vec_size, but got non-default "
"feat_vec_exponent. It will be unused.")
else:
if feat_vec_exponent <= 0:
raise ValueError("Using feat_vec_exponent to determine "
"feature vec size, but got feat_vec_exponent "
"less than or equal to 0.")
n_feats = len(feat_vocab_sizes)
if n_feats != len(feat_padding_idx):
raise ValueError("Got unequal number of feat_vocab_sizes and "
"feat_padding_idx ({:d} != {:d})".format(
n_feats, len(feat_padding_idx)))

@property
def word_lut(self):
""" word look-up table """
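
A quick illustration of how the new _validate_args behaves (a sketch, not part of the diff; the sizes below are arbitrary, and the keyword names follow the constructor signature shown above and exercised by the tests below):

import warnings
from onmt.modules.embeddings import Embeddings

# Merging with sum while also passing feat_vec_size: construction succeeds,
# but a warning flags the unused option.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Embeddings(word_vec_size=64, word_vocab_size=100, word_padding_idx=0,
               feat_merge="sum", feat_vec_size=10,
               feat_padding_idx=[1], feat_vocab_sizes=[7])
assert any("feat_vec_size" in str(w.message) for w in caught)

# Mismatched feature lists fail fast with a ValueError.
try:
    Embeddings(word_vec_size=64, word_vocab_size=100, word_padding_idx=0,
               feat_padding_idx=[1], feat_vocab_sizes=[7, 9])
except ValueError as err:
    print(err)  # unequal number of feat_vocab_sizes and feat_padding_idx
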
153 changes: 153 additions & 0 deletions onmt/tests/test_embeddings.py
@@ -0,0 +1,153 @@
import unittest
from onmt.modules.embeddings import Embeddings

import itertools
from copy import deepcopy

import torch


def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))
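
(For reference: product_dict yields one dict per element of the Cartesian product of the value lists, e.g. list(product_dict(a=[1, 2], b=["x"])) gives [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'x'}]; it is how INIT_CASES and PARAMS below are built.)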


class TestEmbeddings(unittest.TestCase):
INIT_CASES = list(product_dict(
word_vec_size=[172],
word_vocab_size=[319],
word_padding_idx=[17],
position_encoding=[False, True],
feat_merge=["first", "concat", "sum", "mlp"],
feat_vec_exponent=[-1, 1.1, 0.7],
feat_vec_size=[0, 200],
feat_padding_idx=[[], [29], [0, 1]],
feat_vocab_sizes=[[], [39], [401, 39]],
dropout=[0, 0.5],
fix_word_vecs=[False, True]
))
PARAMS = list(product_dict(
batch_size=[1, 14],
max_seq_len=[23]
))

@classmethod
def case_is_degenerate(cls, case):
no_feats = len(case["feat_vocab_sizes"]) == 0
if case["feat_merge"] != "first" and no_feats:
return True
if case["feat_merge"] == "first" and not no_feats:
return True
if case["feat_merge"] == "concat" and case["feat_vec_exponent"] != -1:
return True
if no_feats and case["feat_vec_exponent"] != -1:
return True
if len(case["feat_vocab_sizes"]) != len(case["feat_padding_idx"]):
return True
if case["feat_vec_size"] == 0 and case["feat_vec_exponent"] <= 0:
return True
if case["feat_merge"] == "sum":
if case["feat_vec_exponent"] != -1:
return True
if case["feat_vec_size"] != 0:
return True
if case["feat_vec_size"] != 0 and case["feat_vec_exponent"] != -1:
return True
return False

@classmethod
def cases(cls):
for case in cls.INIT_CASES:
if not cls.case_is_degenerate(case):
yield case

@classmethod
def dummy_inputs(cls, params, init_case):
max_seq_len = params["max_seq_len"]
batch_size = params["batch_size"]
fv_sizes = init_case["feat_vocab_sizes"]
n_words = init_case["word_vocab_size"]
voc_sizes = [n_words] + fv_sizes
pad_idxs = [init_case["word_padding_idx"]] + \
init_case["feat_padding_idx"]
lengths = torch.randint(0, max_seq_len, (batch_size,))
lengths[0] = max_seq_len
inps = torch.empty((max_seq_len, batch_size, len(voc_sizes)),
dtype=torch.long)
for f, (voc_size, pad_idx) in enumerate(zip(voc_sizes, pad_idxs)):
for b, len_ in enumerate(lengths):
inps[:len_, b, f] = torch.randint(0, voc_size-1, (len_,))
inps[len_:, b, f] = pad_idx
return inps

@classmethod
def expected_shape(cls, params, init_case):
wvs = init_case["word_vec_size"]
fvs = init_case["feat_vec_size"]
nf = len(init_case["feat_vocab_sizes"])
size = wvs
if init_case["feat_merge"] not in {"sum", "mlp"}:
size += nf * fvs
return params["max_seq_len"], params["batch_size"], size
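
(A concrete check with values from INIT_CASES above: word_vec_size=172, feat_vec_size=200, two features, and feat_merge="concat" give an embedding size of 172 + 2 * 200 = 572, so the expected output shape is (23, 14, 572) for the larger batch; with "sum" or "mlp" the size stays 172.)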

def test_embeddings_forward_shape(self):
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
dummy_in = self.dummy_inputs(params, init_case)
res = emb(dummy_in)
expected_shape = self.expected_shape(params, init_case)
self.assertEqual(res.shape, expected_shape, init_case.__str__())

def test_embeddings_trainable_params(self):
for params, init_case in itertools.product(self.PARAMS,
self.cases()):
emb = Embeddings(**init_case)
trainable_params = {n: p for n, p in emb.named_parameters()
if p.requires_grad}
# first check there's nothing unexpectedly not trainable
for key in emb.state_dict():
if key not in trainable_params:
if key.endswith("emb_luts.0.weight") and \
init_case["fix_word_vecs"]:
# ok: word embeddings shouldn't be trainable
# if word vecs are fixed
continue
if key.endswith(".pe.pe"):
# ok: positional encodings shouldn't be trainable
assert init_case["position_encoding"]
continue
else:
self.fail("Param {:s} is unexpectedly not "
"trainable.".format(key))
# then check nothing unexpectedly trainable
if init_case["fix_word_vecs"]:
self.assertFalse(
any(trainable_param.endswith("emb_luts.0.weight")
for trainable_param in trainable_params),
"Word embedding is trainable but word vecs are fixed.")
if init_case["position_encoding"]:
self.assertFalse(
any(trainable_p.endswith(".pe.pe")
for trainable_p in trainable_params),
"Positional encoding is trainable.")

def test_embeddings_trainable_params_update(self):
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
trainable_params = {n: p for n, p in emb.named_parameters()
if p.requires_grad}
if len(trainable_params) > 0:
old_weights = deepcopy(trainable_params)
dummy_in = self.dummy_inputs(params, init_case)
res = emb(dummy_in)
pretend_loss = res.sum()
pretend_loss.backward()
dummy_optim = torch.optim.SGD(trainable_params.values(), 1)
dummy_optim.step()
for param_name in old_weights.keys():
self.assertTrue(
trainable_params[param_name]
.ne(old_weights[param_name]).any(),
param_name + " " + init_case.__str__())
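
As usual for unittest-based tests, the new module can presumably be run on its own with something like python -m unittest onmt.tests.test_embeddings (or pytest onmt/tests/test_embeddings.py if pytest is installed), alongside the rest of the onmt/tests suite.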
