forked from cgpotts/cs224u
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_sst.py
93 lines (75 loc) · 2.58 KB
/
test_sst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from collections import Counter
import os
from sklearn.linear_model import LogisticRegression
import sst
from torch_rnn_classifier import TorchRNNClassifier
import pytest
import utils
__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Fall 2020"

# Directory holding the SST tree files that every test below reads from;
# it is resolved relative to the current working directory.
sst_home = os.path.join('data', 'trees')

# Seed the project's random number generators at import time so that the
# model-fitting tests are repeatable (see utils.fix_random_seeds for which
# generators are covered).
utils.fix_random_seeds()
# Each case pairs a zero-argument reader *factory* with the number of
# examples it should yield. The readers are iterators (they can be advanced
# with next()), so a single instance can only be consumed once. Building it
# lazily inside the test means every execution of a case, including
# re-executions by rerun/repeat plugins, counts a fresh reader rather than
# an exhausted one. It also keeps reader construction out of module import.
@pytest.mark.parametrize("reader, count", [
    [lambda: sst.train_reader(sst_home, class_func=None), 8544],
    [lambda: sst.train_reader(sst_home, class_func=sst.binary_class_func), 6920],
    [lambda: sst.train_reader(sst_home, class_func=sst.ternary_class_func), 8544],
    [lambda: sst.dev_reader(sst_home, class_func=None), 1101],
    [lambda: sst.dev_reader(sst_home, class_func=sst.binary_class_func), 872],
    [lambda: sst.dev_reader(sst_home, class_func=sst.ternary_class_func), 1101],
], ids=[
    "train-None", "train-binary", "train-ternary",
    "dev-None", "dev-binary", "dev-ternary",
])
def test_readers(reader, count):
    """Check that a reader/class_func combination yields the expected count.

    ``reader`` is a zero-argument callable returning a fresh SST reader;
    ``count`` is the number of examples it must produce.
    """
    # Count by streaming instead of materializing the whole dataset.
    result = sum(1 for _ in reader())
    assert result == count
def test_reader_labeling():
    """Check ternary labeling of the first training example.

    With ``sst.ternary_class_func``, both the label returned by the reader
    and the label of every node in the example's tree must be one of the
    three ternary classes.
    """
    ternary_labels = {'negative', 'neutral', 'positive'}
    tree, label = next(sst.train_reader(sst_home, class_func=sst.ternary_class_func))
    # The returned label was previously unpacked but never verified.
    assert label in ternary_labels
    for subtree in tree.subtrees():
        assert subtree.label() in ternary_labels
def test_build_dataset_vectorizing():
    """build_dataset with vectorize=True keeps X, y and raw_examples aligned.

    Features are bag-of-words counts over each tree's leaves; the number of
    rows must match the number of examples the train reader yields.
    """
    train = sst.train_reader
    dataset = sst.build_dataset(
        sst_home,
        train,
        lambda tree: Counter(tree.leaves()),
        None,
        vectorizer=None,
        vectorize=True)
    n_rows = len(dataset['X'])
    assert n_rows == sum(1 for _ in train(sst_home))
    for key in ('y', 'raw_examples'):
        assert len(dataset[key]) == n_rows
def test_build_dataset_not_vectorizing():
    """build_dataset with vectorize=False passes features through untouched.

    With an identity featurizer, X must equal raw_examples, and X and y
    must both have one entry per example yielded by the train reader.
    """
    train = sst.train_reader
    dataset = sst.build_dataset(
        sst_home, train, lambda tree: tree, None,
        vectorizer=None, vectorize=False)
    features = dataset['X']
    assert len(features) == sum(1 for _ in train(sst_home))
    assert features == dataset['raw_examples']
    assert len(dataset['y']) == len(features)
def test_build_rnn_dataset():
    """build_rnn_dataset returns parallel X and y sized to the binary train split."""
    examples, labels = sst.build_rnn_dataset(
        sst_home, sst.train_reader, class_func=sst.binary_class_func)
    expected_size = 6920  # number of binary-labeled training examples
    for part in (examples, labels):
        assert len(part) == expected_size
@pytest.mark.parametrize("assess_reader", [
    None,
    sst.dev_reader
])
def test_experiment(assess_reader):
    """Smoke-test sst.experiment end to end.

    Runs once without an assessment reader (presumably sst.experiment then
    assesses on held-out training data, using random_state) and once with
    the dev reader. Every example is featurized as the same constant dict,
    so this checks that the pipeline runs, not that the model learns.
    """
    def fit_maxent(X, y):
        # `multi_class` is deprecated in scikit-learn >= 1.5. The liblinear
        # solver is always one-vs-rest, which is what 'auto' selected for it,
        # so omitting the argument leaves the fitted model unchanged.
        mod = LogisticRegression(solver='liblinear')
        mod.fit(X, y)
        return mod

    sst.experiment(
        sst_home,
        train_reader=sst.train_reader,
        phi=lambda x: {"$UNK": 1},
        train_func=fit_maxent,
        assess_reader=assess_reader,
        random_state=42)