Optional y arg to tree network fit methods to allow cross-validation

cgpotts committed Feb 2, 2020
1 parent fd11554 commit 46ec94c
Showing 5 changed files with 110 additions and 42 deletions.
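In short, the commit makes both of the call patterns below legal for the tree models. A minimal sketch of the new interface (the toy trees and vocabulary are invented here for illustration, in the style of `simple_example` below):

```
from nltk.tree import Tree

from np_tree_nn import TreeNN

# Toy data, invented for illustration:
X_train = [Tree.fromstring("(odd 1)"), Tree.fromstring("(even 2)")]
vocab = ["1", "2", "$UNK"]

model = TreeNN(vocab, embed_dim=50, max_iter=10)

# Default behavior (unchanged): labels are read off the tree roots.
model.fit(X_train)

# New in this commit: labels may be given separately, which is what
# sklearn's cross-validation utilities require.
y_train = [t.label() for t in X_train]
model.fit(X_train, y_train)
```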
5 changes: 5 additions & 0 deletions np_model_base.py
@@ -46,6 +46,10 @@ def fit(self, X, y):
        y : list
            The one-hot label vector.

        Returns
        -------
        self
        """
        y = self.prepare_output_data(y)
        self.initialize_parameters()
@@ -75,6 +79,7 @@ def fit(self, X, y):
            progress_bar(
                "Finished epoch {} of {}; error is {}".format(
                    iteration, self.max_iter, error))
        return self

    @staticmethod
    def get_error(predictions, labels):
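The `return self` added above makes `fit` chainable, which is part of the sklearn estimator contract that cloning and cross-validation utilities depend on. A small illustration, assuming `model`, `X_train`, and `y_train` as in the sketch at the top, and that the model class exposes `predict`, as the tree models here do:

```
# `fit` now returns the estimator itself, so calls can be chained:
predictions = model.fit(X_train, y_train).predict(X_train)
```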
41 changes: 36 additions & 5 deletions np_tree_nn.py
@@ -19,7 +19,25 @@ def __init__(self, vocab, embedding=None, embed_dim=50, **kwargs):
        self.hidden_dim = self.embed_dim * 2

    def fit(self, X, y=None):
        y = [t.label() for t in X]
        """Fairly standard `fit` method except that, if `y=None`,
        then the labels `y` are presumed to come from the root nodes
        of the trees in `X`. We retain the option of giving them
        as a separate argument for consistency with the other model
        interfaces, and so that we can use sklearn cross-validation
        methods with this class.

        Parameters
        ----------
        X : list of `nltk.Tree` instances
        y : iterable of labels, or None

        Returns
        -------
        self

        """
        if y is None:
            y = [t.label() for t in X]
        return super(TreeNN, self).fit(X, y)

    def initialize_parameters(self):

@@ -160,7 +178,7 @@ def set_params(self, **params):
        self.hidden_dim = self.embed_dim * 2


def simple_example():
def simple_example(initial_embedding=False, separate_y=False):
    from nltk.tree import Tree
    import utils

@@ -190,13 +208,26 @@ def simple_example():

    X_test = [Tree.fromstring(x) for x in test]

    if initial_embedding:
        import numpy as np
        embedding = np.random.uniform(
            low=-1.0, high=1.0, size=(len(vocab), 50))
    else:
        embedding = None

    model = TreeNN(
        vocab,
        embed_dim=50,
        hidden_dim=50,
        max_iter=100)
        max_iter=100,
        embedding=embedding)

    if separate_y:
        y = [t.label() for t in X_train]
    else:
        y = None

    model.fit(X_train)
    model.fit(X_train, y=y)

    print("\nTest predictions:")

@@ -213,4 +244,4 @@ def simple_example():


if __name__ == '__main__':
    simple_example()
    simple_example(initial_embedding=False, separate_y=False)
19 changes: 19 additions & 0 deletions sst_03_neural_networks.ipynb
@@ -914,6 +914,18 @@
"%time _ = tree_nn_glove.fit(X_tree_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Important: if you wish to cross-validate this model using scikit-learn methods, then you'll need to give the labels as a separate argument, as in \n",
"\n",
"```\n",
"y_tree_train = [t.label() for t in X_tree_train]\n",
"tree_nn_glove.fit(X_tree_train, y_tree_train)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 41,
@@ -995,6 +1007,13 @@
"%time _ = torch_tree_nn_glove.fit(X_tree_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As with `TreeNN` above, you have the option of specifying the labels separately, and this is required if you are cross-validating the model using scikit-learn methods."
]
},
{
"cell_type": "code",
"execution_count": 45,
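Putting the notebook's advice into practice, a cross-validation run might look like the sketch below. This is hypothetical: it assumes `tree_nn_glove` and `X_tree_train` from the notebook, and that the estimator's `get_params`/`set_params` (see the `set_params` hunk above) and `predict` are sklearn-compatible; the `cv` and `scoring` values are arbitrary:

```
from sklearn.model_selection import cross_val_score

# The labels must be passed separately so sklearn can split X and y together:
y_tree_train = [t.label() for t in X_tree_train]

scores = cross_val_score(
    tree_nn_glove, X_tree_train, y_tree_train,
    cv=3, scoring="accuracy")
print("Mean CV accuracy:", scores.mean())
```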
22 changes: 17 additions & 5 deletions test/test_models.py
@@ -270,13 +270,24 @@ def test_torch_autoencoder_simple_example():
    assert mse < 0.0001


def test_np_tree_nn_simple_example():
    np_tree_nn.simple_example()
@pytest.mark.parametrize("initial_embedding, separate_y", [
    [True, True],
    [True, False],
    [False, True],
    [False, False]
])
def test_np_tree_nn_simple_example(initial_embedding, separate_y):
    np_tree_nn.simple_example(initial_embedding, separate_y)


@pytest.mark.parametrize("initial_embedding", [True, False])
def test_torch_tree_nn_simple_example(initial_embedding):
    torch_tree_nn.simple_example(initial_embedding)
@pytest.mark.parametrize("initial_embedding, separate_y", [
    [True, True],
    [True, False],
    [False, True],
    [False, False]
])
def test_torch_tree_nn_simple_example(initial_embedding, separate_y):
    torch_tree_nn.simple_example(initial_embedding, separate_y)


def test_torch_tree_nn_incremental(X_tree):
@@ -437,6 +448,7 @@ def test_torch_rnn_classifier_save_load(X_sequence):
    mod2.predict(X_test)
    mod2.fit(X, y)


def test_torch_tree_nn_save_load(X_tree):
    X, vocab = X_tree
    mod = torch_tree_nn.TorchTreeNN(
65 changes: 33 additions & 32 deletions torch_tree_nn.py
@@ -32,8 +32,7 @@ def _define_embedding(self, embedding):

    def forward(self, tree):
        """Recursively interprets `tree`, applying a classifier layer
        to the final representation. The label comes from the root
        of the tree itself.
        to the final representation.

        Parameters
        ----------
@@ -45,7 +44,7 @@ def forward(self, tree):
"""
root = self.interpret(tree)
return self.classifier_layer(root), tree.label()
return self.classifier_layer(root)

    def interpret(self, subtree):
        # Terminal nodes are str:
@@ -85,13 +84,18 @@ def build_graph(self):
            output_dim=self.n_classes_,
            hidden_activation=self.hidden_activation)

    def fit(self, X, **kwargs):
        """Fairly standard `fit` method except that the labels `y` are
        presumed to come from the root nodes of the trees in `X`.
    def fit(self, X, y=None, **kwargs):
        """Fairly standard `fit` method except that, if `y=None`,
        then the labels `y` are presumed to come from the root nodes
        of the trees in `X`. We retain the option of giving them
        as a separate argument for consistency with the other model
        interfaces, and so that we can use sklearn cross-validation
        methods with this class.

        Parameters
        ----------
        X : list of `nltk.Tree` instances
        y : iterable of labels, or None
        kwargs : dict
            For passing other parameters. If 'X_dev' is included,
            then performance is monitored every 10 epochs; use
@@ -102,28 +106,35 @@ def fit(self, X, **kwargs):
        self

        """
        # Labels:
        if y is None:
            y = [t.label() for t in X]
        self.classes_ = sorted(set(y))
        self.n_classes_ = len(self.classes_)
        self.class2index = dict(zip(self.classes_, range(self.n_classes_)))

        # Incremental performance:
        X_dev = kwargs.get('X_dev')
        if X_dev is not None:
            dev_iter = kwargs.get('dev_iter', 10)
        # Data prep:
        self.classes_ = self.get_classes(X)
        self.n_classes_ = len(self.classes_)
        self.class2index = dict(zip(self.classes_, range(self.n_classes_)))

        # Model:
        if not self.warm_start or not hasattr(self, "model"):
            self.model = self.build_graph()
            self.model.to(self.device)
        self.model.train()

        # Optimization:
        loss = nn.CrossEntropyLoss()
        optimizer = self.optimizer(self.model.parameters(), lr=self.eta)

        # Train:
        dataset = list(zip(X, y))
        for iteration in range(1, self.max_iter+1):
            epoch_error = 0.0
            random.shuffle(X)
            for tree in X:
                pred, label = self.model.forward(tree)
            random.shuffle(dataset)
            for tree, label in dataset:
                pred = self.model.forward(tree)
                label = self.convert_label(label)
                err = loss(pred, label)
                epoch_error += err.item()
@@ -140,21 +151,6 @@ def fit(self, X, **kwargs):
                    iteration, self.max_iter, epoch_error/len(X)))
        return self

    @staticmethod
    def get_classes(X):
        """Classes as given by the root nodes in `X`.

        Parameters
        ----------
        X : list of nltk.tree.Tree

        Returns
        -------
        list
        """
        return sorted({t.label() for t in X})

    def convert_label(self, label):
        """Convert a class label to a format that PyTorch can handle.
@@ -186,7 +182,7 @@ def predict_proba(self, X):
        with torch.no_grad():
            preds = []
            for tree in X:
                pred, _ = self.model.forward(tree)
                pred = self.model.forward(tree)
                preds.append(pred.squeeze())
            preds = torch.stack(preds)
            return torch.softmax(preds, dim=1).numpy()
@@ -209,7 +205,7 @@ def predict(self, X):
        return [self.classes_[i] for i in probs.argmax(axis=1)]


def simple_example(initial_embedding=False):
def simple_example(initial_embedding=False, separate_y=False):
    from nltk.tree import Tree

    train = [
@@ -249,7 +245,12 @@ def simple_example(initial_embedding=False):
        max_iter=50,
        embedding=embedding)

    mod.fit(X_train)
    if separate_y:
        y = [t.label() for t in X_train]
    else:
        y = None

    mod.fit(X_train, y=y)

    print("\nTest predictions:")

@@ -266,4 +267,4 @@ def simple_example():


if __name__ == '__main__':
    simple_example()
    simple_example(separate_y=True)
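A side effect of moving label handling out of `forward` and `predict_proba` is that prediction no longer reads tree labels at all, so trees passed to `predict` can carry dummy node labels. A small sketch, assuming a fitted `TorchTreeNN` instance `mod` trained as in `simple_example` (the input tree is invented, with placeholder `x` labels, and its leaves are assumed to be in the training vocabulary):

```
from nltk.tree import Tree

# Node labels are ignored at prediction time now that `forward`
# returns only the classifier output:
t = Tree.fromstring("(x (x 1 1) (x 2 2))")
print(mod.predict([t]))
```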
