forked from cgpotts/cs224u
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_glove.py
116 lines (88 loc) · 2.81 KB
/
test_glove.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import pandas as pd
import pytest
import tempfile
import torch.nn as nn
import utils
from test_torch_model_base import PARAMS_WITH_TEST_VALUES as BASE_PARAMS
from torch_glove import TorchGloVe, simple_example
from np_glove import GloVe
# Module metadata for this test file.
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Fall 2020"
# NOTE(review): presumably seeds the random number generators so that the
# score thresholds asserted in the tests below are reproducible -- confirm
# against the implementation in `utils`.
utils.fix_random_seeds()
# GloVe-specific (parameter name, test value) pairs used to parametrize the
# tests below; the pairs imported from the base-model test module are
# appended after them.
PARAMS_WITH_TEST_VALUES = [
    ["embed_dim", 20],
    ["alpha", 0.65],
    ["xmax", 75],
]
PARAMS_WITH_TEST_VALUES.extend(BASE_PARAMS)
@pytest.fixture
def count_matrix():
    """Provide the small symmetric 4x4 matrix of float counts used as
    input data by the tests in this module.
    """
    rows = [
        [4., 4., 2., 0.],
        [4., 61., 8., 18.],
        [2., 8., 10., 0.],
        [0., 18., 0., 5.],
    ]
    return np.array(rows)
@pytest.mark.parametrize("pandas", [True, False])
def test_model(count_matrix, pandas):
    """Check that `fit` returns a `pd.DataFrame` exactly when it was
    given one as input.
    """
    data = pd.DataFrame(count_matrix) if pandas else count_matrix
    fitted = TorchGloVe().fit(data)
    assert isinstance(fitted, pd.DataFrame) == pandas
def test_simple_example():
    """Check that the value returned by `simple_example` exceeds the
    0.43 threshold.
    """
    assert simple_example() > 0.43
@pytest.mark.parametrize("param, expected", PARAMS_WITH_TEST_VALUES)
def test_params(param, expected):
    """Check that a keyword argument passed to the constructor is
    exposed unchanged as an attribute of the same name.
    """
    kwargs = {param: expected}
    model = TorchGloVe(**kwargs)
    assert getattr(model, param) == expected
@pytest.mark.parametrize("param, expected", PARAMS_WITH_TEST_VALUES)
def test_simple_example_params(count_matrix, param, expected):
    """Fit a model built with a single non-default parameter and check
    that its score on the training matrix stays above 0.40.

    The score check is skipped for the `max_iter == 0` case.
    """
    X = count_matrix
    mod = TorchGloVe(**{param: expected})
    # The return value of `fit` is not needed here; only the score is checked.
    mod.fit(X)
    corr = mod.score(X)
    # NOTE(review): presumably a model run for zero iterations is untrained,
    # so no minimum score can be expected of it -- confirm.
    zero_iterations = param == "max_iter" and expected == 0
    if not zero_iterations:
        assert corr > 0.40
@pytest.mark.parametrize("param, expected", PARAMS_WITH_TEST_VALUES)
def test_parameter_setting(param, expected):
    """Check that `set_params` stores the given value on the attribute
    of the same name.
    """
    model = TorchGloVe()
    new_setting = {param: expected}
    model.set_params(**new_setting)
    assert getattr(model, param) == expected
def test_build_dataset(count_matrix):
    """Check that the first item yielded by the dataset from
    `build_dataset` has exactly three components.
    """
    model = TorchGloVe()
    # The real weights need not be computed for this check, so the count
    # matrix itself is passed in their place.
    dataset = model.build_dataset(count_matrix, count_matrix)
    first_item = next(iter(dataset))
    assert len(first_item) == 3
@pytest.mark.parametrize("param", ["W", "C"])
def test_model_graph_embed_dim(count_matrix, param):
    """Check that, after fitting, the second dimension of the named
    attribute of `mod.model` equals the wrapper's `embed_dim`.
    """
    glove = TorchGloVe(max_iter=1)
    glove.fit(count_matrix)
    graph_matrix = getattr(glove.model, param)
    assert glove.embed_dim == graph_matrix.shape[1]
def test_save_load(count_matrix):
    """Fit a model, write it out with `to_pickle`, read it back with
    `from_pickle`, and check that the reloaded model can be fit again.
    """
    import os

    X = count_matrix
    mod = TorchGloVe(max_iter=2)
    mod.fit(X)
    # Pickle into a fresh path inside a temporary directory. Reopening a
    # still-open `NamedTemporaryFile` by name is not portable: with the
    # default arguments it fails on Windows.
    with tempfile.TemporaryDirectory() as tmpdir:
        name = os.path.join(tmpdir, "torch_glove.pkl")
        mod.to_pickle(name)
        mod2 = TorchGloVe.from_pickle(name)
        mod2.fit(X)
def test_np_glove(count_matrix):
    """
    Smoke test for the NumPy implementation: it only verifies that
    fitting runs to completion, not that the resulting model is good.
    """
    GloVe().fit(count_matrix)