Add unit tests for embeddings module #1241

Merged · 1 commit · Jan 31, 2019
41 changes: 37 additions & 4 deletions onmt/modules/embeddings.py
@@ -1,5 +1,6 @@
""" Embeddings module """
import math
import warnings

import torch
import torch.nn as nn
@@ -21,6 +22,9 @@ class PositionalEncoding(nn.Module):
"""

def __init__(self, dropout, dim, max_len=5000):
if dim % 2 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd dim (got dim={:d})".format(dim))
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
@@ -80,7 +84,7 @@ class Embeddings(nn.Module):
concat, sum or mlp.
feat_vec_exponent (float): when using `-feat_merge concat`, feature
embedding size is N^feat_dim_exponent, where N is the
- number of values of feature takes.
+ number of values the feature takes.
feat_vec_size (int): embedding dimension for features when using
`-feat_merge mlp`
dropout (float): dropout probability.
@@ -91,12 +95,15 @@ def __init__(self, word_vec_size,
word_padding_idx,
position_encoding=False,
feat_merge="concat",
- feat_vec_exponent=0.7, feat_vec_size=-1,
+ feat_vec_exponent=0.7,
+ feat_vec_size=-1,
feat_padding_idx=[],
feat_vocab_sizes=[],
dropout=0,
sparse=False,
fix_word_vecs=False):
self._validate_args(feat_merge, feat_vocab_sizes, feat_vec_exponent,
feat_vec_size, feat_padding_idx)

if feat_padding_idx is None:
feat_padding_idx = []
@@ -147,8 +154,7 @@ def __init__(self, word_vec_size,

if feat_merge == 'mlp' and len(feat_vocab_sizes) > 0:
in_dim = sum(emb_dims)
- out_dim = word_vec_size
- mlp = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU())
+ mlp = nn.Sequential(nn.Linear(in_dim, word_vec_size), nn.ReLU())
self.make_embedding.add_module('mlp', mlp)

self.position_encoding = position_encoding
@@ -160,6 +166,33 @@ def __init__(self, word_vec_size,
if fix_word_vecs:
self.word_lut.weight.requires_grad = False

def _validate_args(self, feat_merge, feat_vocab_sizes, feat_vec_exponent,
feat_vec_size, feat_padding_idx):
if feat_merge == "sum":
# features must use word_vec_size
if feat_vec_exponent != 0.7:
warnings.warn("Merging with sum, but got non-default "
"feat_vec_exponent. It will be unused.")
if feat_vec_size != -1:
warnings.warn("Merging with sum, but got non-default "
"feat_vec_size. It will be unused.")
elif feat_vec_size > 0:
# features will use feat_vec_size
if feat_vec_exponent != -1:
warnings.warn("Not merging with sum and positive "
"feat_vec_size, but got non-default "
"feat_vec_exponent. It will be unused.")
else:
if feat_vec_exponent <= 0:
raise ValueError("Using feat_vec_exponent to determine "
"feature vec size, but got feat_vec_exponent "
"less than or equal to 0.")
n_feats = len(feat_vocab_sizes)
if n_feats != len(feat_padding_idx):
raise ValueError("Got unequal number of feat_vocab_sizes and "
"feat_padding_idx ({:d} != {:d})".format(
n_feats, len(feat_padding_idx)))

@property
def word_lut(self):
""" word look-up table """
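
As a quick reference, a minimal sketch of how the new argument validation behaves (this assumes an OpenNMT-py checkout with this change installed; the argument values below are illustrative, not taken from the PR):

import warnings

from onmt.modules.embeddings import Embeddings, PositionalEncoding

# Odd model dimensions are now rejected up front.
try:
    PositionalEncoding(dropout=0.1, dim=7)
except ValueError as err:
    print(err)  # Cannot use sin/cos positional encoding with odd dim (got dim=7)

# Mismatched feature vocab/padding lists raise before any modules are built.
try:
    Embeddings(word_vec_size=64, word_vocab_size=100, word_padding_idx=1,
               feat_vocab_sizes=[10, 20], feat_padding_idx=[0])
except ValueError as err:
    print(err)  # Got unequal number of feat_vocab_sizes and feat_padding_idx (2 != 1)

# Unused feature options under feat_merge="sum" only warn; construction proceeds.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Embeddings(word_vec_size=64, word_vocab_size=100, word_padding_idx=1,
               feat_merge="sum", feat_vec_size=64,
               feat_vocab_sizes=[10], feat_padding_idx=[0])
    print([str(w.message) for w in caught])
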
153 changes: 153 additions & 0 deletions onmt/tests/test_embeddings.py
@@ -0,0 +1,153 @@
import unittest
from onmt.modules.embeddings import Embeddings

import itertools
from copy import deepcopy

import torch


def product_dict(**kwargs):
keys = kwargs.keys()
vals = kwargs.values()
for instance in itertools.product(*vals):
yield dict(zip(keys, instance))


class TestEmbeddings(unittest.TestCase):
INIT_CASES = list(product_dict(
word_vec_size=[172],
word_vocab_size=[319],
word_padding_idx=[17],
position_encoding=[False, True],
feat_merge=["first", "concat", "sum", "mlp"],
feat_vec_exponent=[-1, 1.1, 0.7],
feat_vec_size=[0, 200],
feat_padding_idx=[[], [29], [0, 1]],
feat_vocab_sizes=[[], [39], [401, 39]],
dropout=[0, 0.5],
fix_word_vecs=[False, True]
))
PARAMS = list(product_dict(
batch_size=[1, 14],
max_seq_len=[23]
))

@classmethod
def case_is_degenerate(cls, case):
no_feats = len(case["feat_vocab_sizes"]) == 0
if case["feat_merge"] != "first" and no_feats:
return True
if case["feat_merge"] == "first" and not no_feats:
return True
if case["feat_merge"] == "concat" and case["feat_vec_exponent"] != -1:
return True
if no_feats and case["feat_vec_exponent"] != -1:
return True
if len(case["feat_vocab_sizes"]) != len(case["feat_padding_idx"]):
return True
if case["feat_vec_size"] == 0 and case["feat_vec_exponent"] <= 0:
return True
if case["feat_merge"] == "sum":
if case["feat_vec_exponent"] != -1:
return True
if case["feat_vec_size"] != 0:
return True
if case["feat_vec_size"] != 0 and case["feat_vec_exponent"] != -1:
return True
return False

@classmethod
def cases(cls):
for case in cls.INIT_CASES:
if not cls.case_is_degenerate(case):
yield case

@classmethod
def dummy_inputs(cls, params, init_case):
max_seq_len = params["max_seq_len"]
batch_size = params["batch_size"]
fv_sizes = init_case["feat_vocab_sizes"]
n_words = init_case["word_vocab_size"]
voc_sizes = [n_words] + fv_sizes
pad_idxs = [init_case["word_padding_idx"]] + \
init_case["feat_padding_idx"]
lengths = torch.randint(0, max_seq_len, (batch_size,))
lengths[0] = max_seq_len
inps = torch.empty((max_seq_len, batch_size, len(voc_sizes)),
dtype=torch.long)
for f, (voc_size, pad_idx) in enumerate(zip(voc_sizes, pad_idxs)):
for b, len_ in enumerate(lengths):
inps[:len_, b, f] = torch.randint(0, voc_size-1, (len_,))
inps[len_:, b, f] = pad_idx
return inps

@classmethod
def expected_shape(cls, params, init_case):
wvs = init_case["word_vec_size"]
fvs = init_case["feat_vec_size"]
nf = len(init_case["feat_vocab_sizes"])
size = wvs
if init_case["feat_merge"] not in {"sum", "mlp"}:
size += nf * fvs
return params["max_seq_len"], params["batch_size"], size

def test_embeddings_forward_shape(self):
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
dummy_in = self.dummy_inputs(params, init_case)
res = emb(dummy_in)
expected_shape = self.expected_shape(params, init_case)
self.assertEqual(res.shape, expected_shape, init_case.__str__())

def test_embeddings_trainable_params(self):
for params, init_case in itertools.product(self.PARAMS,
self.cases()):
emb = Embeddings(**init_case)
trainable_params = {n: p for n, p in emb.named_parameters()
if p.requires_grad}
# first check there's nothing unexpectedly not trainable
for key in emb.state_dict():
if key not in trainable_params:
if key.endswith("emb_luts.0.weight") and \
init_case["fix_word_vecs"]:
# ok: word embeddings shouldn't be trainable
# if word vecs are fixed
continue
if key.endswith(".pe.pe"):
# ok: positional encodings shouldn't be trainable
assert init_case["position_encoding"]
continue
else:
self.fail("Param {:s} is unexpectedly not "
"trainable.".format(key))
# then check nothing unexpectedly trainable
if init_case["fix_word_vecs"]:
self.assertFalse(
any(trainable_param.endswith("emb_luts.0.weight")
for trainable_param in trainable_params),
"Word embedding is trainable but word vecs are fixed.")
if init_case["position_encoding"]:
self.assertFalse(
any(trainable_p.endswith(".pe.pe")
for trainable_p in trainable_params),
"Positional encoding is trainable.")

def test_embeddings_trainable_params_update(self):
for params, init_case in itertools.product(self.PARAMS, self.cases()):
emb = Embeddings(**init_case)
trainable_params = {n: p for n, p in emb.named_parameters()
if p.requires_grad}
if len(trainable_params) > 0:
old_weights = deepcopy(trainable_params)
dummy_in = self.dummy_inputs(params, init_case)
res = emb(dummy_in)
pretend_loss = res.sum()
pretend_loss.backward()
dummy_optim = torch.optim.SGD(trainable_params.values(), 1)
dummy_optim.step()
for param_name in old_weights.keys():
self.assertTrue(
trainable_params[param_name]
.ne(old_weights[param_name]).any(),
param_name + " " + init_case.__str__())
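
Since the new module is a standard unittest.TestCase, a quick way to exercise it outside of CI is a small runner like the sketch below (it assumes the onmt package is importable, e.g. from a repository-root checkout):

import unittest

# Load and run the test module added by this PR.
suite = unittest.defaultTestLoader.loadTestsFromName("onmt.tests.test_embeddings")
unittest.TextTestRunner(verbosity=2).run(suite)

Equivalently, running python -m unittest onmt.tests.test_embeddings -v from the repository root does the same from the command line.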