Skip to content

Commit

Permalink
Merge pull request cgpotts#74 from TianHuaBooks/featurize_func
Browse files Browse the repository at this point in the history
Add parameter featurize_func
  • Loading branch information
cgpotts authored Nov 16, 2020
2 parents e449004 + dec173a commit 2920b14
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def wordentail_experiment(
assess_data,
vector_func,
vector_combo_func,
model):
model,
featurize_func=word_entail_featurize,
):
"""Train and evaluation code for the word-level entailment task.
Parameters
Expand All @@ -48,6 +50,7 @@ def wordentail_experiment(
Any function for combining two vectors into a new vector
of fixed dimensionality.
model : class with `fit` and `predict` methods
featurize_func : function to return feature (X,y) with intended tensor
Prints
------
Expand All @@ -69,9 +72,9 @@ def wordentail_experiment(
between these experiments and the bake-off evaluation.
"""
X_train, y_train = word_entail_featurize(
X_train, y_train = featurize_func(
train_data, vector_func, vector_combo_func)
X_dev, y_dev = word_entail_featurize(
X_dev, y_dev = featurize_func(
assess_data, vector_func, vector_combo_func)
model.fit(X_train, y_train)
predictions = model.predict(X_dev)
Expand All @@ -87,7 +90,7 @@ def wordentail_experiment(
'vector_combo_func': vector_combo_func}


def bake_off_evaluation(experiment_results, test_data_filename=None):
def bake_off_evaluation(experiment_results, test_data_filename=None, featurize_func=word_entail_featurize):
"""Function for evaluating a trained model on the bake-off test set.
Parameters
Expand All @@ -98,6 +101,7 @@ def bake_off_evaluation(experiment_results, test_data_filename=None):
test_data_filename : str or None
Full path to the test data. If `None`, then we assume the file is
'data/nlidata/nli_wordentail_bakeoff_data-test.json'.
featurize_func : function to return feature (X,y) with intended tensor
Prints
------
Expand All @@ -110,7 +114,7 @@ def bake_off_evaluation(experiment_results, test_data_filename=None):
'data', 'nlidata', 'nli_wordentail_bakeoff_data-test.json')
with open(test_data_filename, encoding='utf8') as f:
wordentail_data = json.load(f)
X_test, y_test = word_entail_featurize(
X_test, y_test = featurize_func(
wordentail_data['test'],
vector_func=experiment_results['vector_func'],
vector_combo_func=experiment_results['vector_combo_func'])
Expand Down

0 comments on commit 2920b14

Please sign in to comment.