Skip to content

Commit

Permalink
SST unit improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
cgpotts committed Mar 22, 2019
1 parent 5a4c0ed commit 2b0eb70
Show file tree
Hide file tree
Showing 2 changed files with 288 additions and 107 deletions.
90 changes: 83 additions & 7 deletions sst_02_hand_built_features.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
" 1. [Assessing BasicSGDClassifier](#Assessing-BasicSGDClassifier)\n",
" 1. [Comparison with the baselines from Socher et al. 2013](#Comparison-with-the-baselines-from-Socher-et-al.-2013)\n",
" 1. [A shallow neural network classifier](#A-shallow-neural-network-classifier)\n",
" 1. [A softmax classifier in PyTorch](#A-softmax-classifier-in-PyTorch)\n",
"1. [Hyperparameter search](#Hyperparameter-search)\n",
" 1. [utils.fit_classifier_with_crossvalidation](#utils.fit_classifier_with_crossvalidation)\n",
" 1. [Example using LogisticRegression](#Example-using-LogisticRegression)\n",
Expand Down Expand Up @@ -94,6 +95,7 @@
"from sklearn.linear_model import LogisticRegression\n",
"import scipy.stats\n",
"from np_sgd_classifier import BasicSGDClassifier\n",
"import torch.nn as nn\n",
"from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
"import sst\n",
"import utils"
Expand Down Expand Up @@ -725,6 +727,80 @@
"It looks like, with enough iterations (and perhaps some fiddling with the activation function and hidden dimensionality), this classifier would meet or exceed the baseline set up by `LogisticRegression`."
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"### A softmax classifier in PyTorch\n",
"\n",
"Our PyTorch modules should support easy modification. For example, to turn `TorchShallowNeuralClassifier` into a `TorchSoftmaxClassifier`, one need only write a new `define_graph` method:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"class TorchSoftmaxClassifier(TorchShallowNeuralClassifier):\n",
" \n",
" def define_graph(self):\n",
" return nn.Linear(self.input_dim, self.n_classes_)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"def fit_torch_softmax(X, y):\n",
" mod = TorchSoftmaxClassifier(max_iter=100)\n",
" mod.fit(X, y)\n",
" return mod"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Finished epoch 100 of 100; error is 0.08181965257972479"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.780\n",
" precision recall f1-score support\n",
"\n",
" negative 0.773 0.765 0.769 995\n",
" positive 0.786 0.794 0.790 1081\n",
"\n",
" micro avg 0.780 0.780 0.780 2076\n",
" macro avg 0.780 0.779 0.779 2076\n",
"weighted avg 0.780 0.780 0.780 2076\n",
"\n"
]
}
],
"source": [
"_ = sst.experiment(\n",
" SST_HOME,\n",
" unigrams_phi, \n",
" fit_torch_softmax, \n",
" class_func=sst.binary_class_func)"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -772,7 +848,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -818,7 +894,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 22,
"metadata": {
"slideshow": {
"slide_type": "-"
Expand Down Expand Up @@ -872,7 +948,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -887,7 +963,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 24,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -953,7 +1029,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 25,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -996,7 +1072,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1009,7 +1085,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 27,
"metadata": {},
"outputs": [
{
Expand Down
Loading

0 comments on commit 2b0eb70

Please sign in to comment.