Skip to content

Commit

Permalink
Additional incremental updates
Browse files Browse the repository at this point in the history
cgpotts committed Mar 22, 2022
1 parent d205976 commit 56c29cc
Showing 11 changed files with 695 additions and 316 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -67,3 +67,6 @@ nlidata/*
rel_ext_data*
*_solved.ipynb
.DS_Store

ColBERT*
experiments*
48 changes: 23 additions & 25 deletions evaluation_methods.ipynb

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions evaluation_metrics.ipynb

Large diffs are not rendered by default.

214 changes: 122 additions & 92 deletions feature_attribution.ipynb

Large diffs are not rendered by default.

117 changes: 58 additions & 59 deletions finetuning.ipynb
Original file line number Diff line number Diff line change
@@ -67,6 +67,7 @@
"from sklearn.metrics import classification_report\n",
"import torch\n",
"import torch.nn as nn\n",
"import transformers\n",
"from transformers import BertModel, BertTokenizer\n",
"\n",
"from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n",
@@ -109,9 +110,7 @@
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"logger = logging.getLogger()\n",
"logger.level = logging.ERROR"
"transformers.logging.set_verbosity_error()"
]
},
{
@@ -213,7 +212,7 @@
"dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])"
]
},
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -241,7 +240,7 @@
" [101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102, 0, 0]]"
]
},
"execution_count": 14,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -270,7 +269,7 @@
"[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]"
]
},
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -317,7 +316,7 @@
"torch.Size([2, 768])"
]
},
"execution_count": 17,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
@@ -346,7 +345,7 @@
"torch.Size([2, 12, 768])"
]
},
"execution_count": 18,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -467,8 +466,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3h 3min 16s, sys: 1h 9min 9s, total: 4h 12min 26s\n",
"Wall time: 35min 24s\n"
"CPU times: user 32min 44s, sys: 52.8 s, total: 33min 37s\n",
"Wall time: 8min 24s\n"
]
}
],
@@ -485,8 +484,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 20min 28s, sys: 6min 30s, total: 26min 59s\n",
"Wall time: 4min 7s\n"
"CPU times: user 4min 14s, sys: 7.2 s, total: 4min 22s\n",
"Wall time: 1min 5s\n"
]
}
],
@@ -521,15 +520,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 23. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.422645628452301"
"Stopping after epoch 45. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.156181752681732"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 9.19 s, sys: 355 ms, total: 9.55 s\n",
"Wall time: 3.09 s\n"
"CPU times: user 21.3 s, sys: 2.56 s, total: 23.9 s\n",
"Wall time: 8.85 s\n"
]
}
],
@@ -557,13 +556,13 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.732 0.741 0.736 428\n",
" neutral 0.397 0.135 0.202 229\n",
" positive 0.659 0.876 0.752 444\n",
" negative 0.696 0.787 0.739 428\n",
" neutral 0.342 0.279 0.308 229\n",
" positive 0.756 0.732 0.744 444\n",
"\n",
" accuracy 0.669 1101\n",
" macro avg 0.596 0.584 0.564 1101\n",
"weighted avg 0.633 0.669 0.632 1101\n",
" accuracy 0.659 1101\n",
" macro avg 0.598 0.600 0.597 1101\n",
"weighted avg 0.647 0.659 0.651 1101\n",
"\n"
]
}
@@ -604,7 +603,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 40. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.189640045166016"
"Stopping after epoch 39. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.242022633552551"
]
},
{
@@ -613,16 +612,16 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.726 0.757 0.741 428\n",
" neutral 0.412 0.175 0.245 229\n",
" positive 0.688 0.865 0.766 444\n",
" negative 0.701 0.806 0.750 428\n",
" neutral 0.435 0.162 0.236 229\n",
" positive 0.714 0.842 0.773 444\n",
"\n",
" accuracy 0.679 1101\n",
" macro avg 0.609 0.599 0.584 1101\n",
"weighted avg 0.646 0.679 0.648 1101\n",
" accuracy 0.687 1101\n",
" macro avg 0.617 0.603 0.586 1101\n",
"weighted avg 0.651 0.687 0.652 1101\n",
"\n",
"CPU times: user 3h 25min 19s, sys: 1h 14min 58s, total: 4h 40min 17s\n",
"Wall time: 39min 33s\n"
"CPU times: user 38min 14s, sys: 1min 2s, total: 39min 17s\n",
"Wall time: 9min 49s\n"
]
}
],
@@ -675,7 +674,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 35. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.5038421787321568"
"Stopping after epoch 32. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.7171962857246399"
]
},
{
@@ -684,16 +683,16 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.708 0.687 0.698 428\n",
" neutral 0.355 0.328 0.341 229\n",
" positive 0.726 0.777 0.751 444\n",
" negative 0.702 0.776 0.737 428\n",
" neutral 0.351 0.236 0.282 229\n",
" positive 0.747 0.797 0.771 444\n",
"\n",
" accuracy 0.649 1101\n",
" macro avg 0.597 0.597 0.596 1101\n",
"weighted avg 0.642 0.649 0.645 1101\n",
" accuracy 0.672 1101\n",
" macro avg 0.600 0.603 0.597 1101\n",
"weighted avg 0.647 0.672 0.656 1101\n",
"\n",
"CPU times: user 3h 26min 32s, sys: 1h 15min 19s, total: 4h 41min 51s\n",
"Wall time: 39min 59s\n"
"CPU times: user 38min 45s, sys: 1min 39s, total: 40min 24s\n",
"Wall time: 10min 6s\n"
]
}
],
@@ -870,27 +869,27 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Finished epoch 1 of 1; error is 96.940810058265926"
"Finished epoch 1 of 1; error is 184.64238105341792"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best params: {'eta': 0.0001, 'gradient_accumulation_steps': 8, 'hidden_dim': 300}\n",
"Best score: 0.586\n",
"Best params: {'eta': 5e-05, 'gradient_accumulation_steps': 4, 'hidden_dim': 200}\n",
"Best score: 0.587\n",
" precision recall f1-score support\n",
"\n",
" negative 0.715 0.808 0.759 428\n",
" neutral 0.700 0.031 0.059 229\n",
" positive 0.662 0.905 0.765 444\n",
" negative 0.686 0.930 0.790 428\n",
" neutral 0.514 0.079 0.136 229\n",
" positive 0.763 0.836 0.798 444\n",
"\n",
" accuracy 0.686 1101\n",
" macro avg 0.692 0.581 0.527 1101\n",
"weighted avg 0.691 0.686 0.616 1101\n",
" accuracy 0.715 1101\n",
" macro avg 0.655 0.615 0.575 1101\n",
"weighted avg 0.682 0.715 0.657 1101\n",
"\n",
"CPU times: user 1h 48min 23s, sys: 5min 23s, total: 1h 53min 47s\n",
"Wall time: 1h 55min 41s\n"
"CPU times: user 1h 27min 12s, sys: 11min 18s, total: 1h 38min 31s\n",
"Wall time: 1h 37min 44s\n"
]
}
],
@@ -961,7 +960,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 9. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 7.519404984079301"
"Stopping after epoch 9. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 11.503188711278199"
]
},
{
@@ -970,16 +969,16 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.756 0.825 0.789 912\n",
" neutral 0.338 0.314 0.325 389\n",
" positive 0.821 0.771 0.795 909\n",
" negative 0.816 0.754 0.784 912\n",
" neutral 0.332 0.501 0.400 389\n",
" positive 0.881 0.756 0.813 909\n",
"\n",
" accuracy 0.713 2210\n",
" macro avg 0.638 0.636 0.636 2210\n",
"weighted avg 0.709 0.713 0.710 2210\n",
" accuracy 0.710 2210\n",
" macro avg 0.676 0.670 0.666 2210\n",
"weighted avg 0.758 0.710 0.728 2210\n",
"\n",
"CPU times: user 13min 7s, sys: 19 s, total: 13min 26s\n",
"Wall time: 13min 27s\n"
"CPU times: user 9min 54s, sys: 1min 22s, total: 11min 17s\n",
"Wall time: 11min 16s\n"
]
}
],
471 changes: 396 additions & 75 deletions nli_02_models.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -31,3 +31,4 @@ torchvision==0.11.1
transformers==4.17.0
datasets==2.0.0
spacy
gitpython
2 changes: 1 addition & 1 deletion setup.ipynb
Original file line number Diff line number Diff line change
@@ -97,7 +97,7 @@
"\n",
"We recommend that you download it, unzip it, and place it in the same directory as your local copy of this Github repository. If you decide to put it somewhere else, you'll need to adjust the paths given in the \"Set-up\" sections of essentially all the notebooks.\n",
"\n",
"We recommend you to check the `md5` checksum of the `data.tgz` afte the download. The current version (as of 8/22/2021), the checksum is `a447b2a81835707ad7882f8f881af79a`. If you see the different checksum, then ask this to the teaching staff."
"We recommend you to check the `md5` checksum of the `data.tgz` after the download. The current version (as of 8/22/2021), the checksum is `a447b2a81835707ad7882f8f881af79a`. If you see the different checksum, then ask this to the teaching staff."
]
},
{
94 changes: 47 additions & 47 deletions sst_03_neural_networks.ipynb
Original file line number Diff line number Diff line change
@@ -223,8 +223,8 @@
" macro avg 0.544 0.521 0.480 1101\n",
"weighted avg 0.571 0.611 0.555 1101\n",
"\n",
"CPU times: user 2.12 s, sys: 52.9 ms, total: 2.18 s\n",
"Wall time: 2.18 s\n"
"CPU times: user 2.48 s, sys: 75.5 ms, total: 2.56 s\n",
"Wall time: 2.5 s\n"
]
}
],
@@ -305,16 +305,16 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.591 0.673 0.630 428\n",
" negative 0.593 0.673 0.630 428\n",
" neutral 0.423 0.048 0.086 229\n",
" positive 0.560 0.741 0.638 444\n",
" positive 0.560 0.743 0.639 444\n",
"\n",
" accuracy 0.570 1101\n",
" macro avg 0.525 0.487 0.451 1101\n",
"weighted avg 0.544 0.570 0.520 1101\n",
" accuracy 0.571 1101\n",
" macro avg 0.525 0.488 0.452 1101\n",
"weighted avg 0.544 0.571 0.521 1101\n",
"\n",
"CPU times: user 3.64 s, sys: 41 ms, total: 3.68 s\n",
"Wall time: 3.69 s\n"
"CPU times: user 4.62 s, sys: 340 ms, total: 4.96 s\n",
"Wall time: 4.4 s\n"
]
}
],
@@ -574,15 +574,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 58. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.2520811893045902"
"Stopping after epoch 58. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.2886183727532625"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6min 37s, sys: 27.9 s, total: 7min 5s\n",
"Wall time: 2min 52s\n"
"CPU times: user 38.9 s, sys: 24.6 s, total: 1min 3s\n",
"Wall time: 19.2 s\n"
]
}
],
@@ -610,13 +610,13 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.589 0.565 0.577 428\n",
" neutral 0.250 0.249 0.249 229\n",
" positive 0.621 0.646 0.634 444\n",
" negative 0.575 0.614 0.594 428\n",
" neutral 0.230 0.223 0.226 229\n",
" positive 0.637 0.606 0.621 444\n",
"\n",
" accuracy 0.532 1101\n",
" macro avg 0.487 0.487 0.487 1101\n",
"weighted avg 0.531 0.532 0.532 1101\n",
" accuracy 0.530 1101\n",
" macro avg 0.481 0.481 0.481 1101\n",
"weighted avg 0.529 0.530 0.529 1101\n",
"\n"
]
}
@@ -684,15 +684,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 22. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.7556907385587692"
"Stopping after epoch 27. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.3226494677364826"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3min 7s, sys: 16.6 s, total: 3min 23s\n",
"Wall time: 1min 29s\n"
"CPU times: user 13.1 s, sys: 9.15 s, total: 22.2 s\n",
"Wall time: 5.63 s\n"
]
}
],
@@ -720,13 +720,13 @@
"text": [
" precision recall f1-score support\n",
"\n",
" negative 0.642 0.757 0.695 428\n",
" neutral 0.250 0.157 0.193 229\n",
" positive 0.695 0.707 0.701 444\n",
" negative 0.676 0.664 0.670 428\n",
" neutral 0.307 0.323 0.315 229\n",
" positive 0.700 0.694 0.697 444\n",
"\n",
" accuracy 0.612 1101\n",
" macro avg 0.529 0.540 0.529 1101\n",
"weighted avg 0.582 0.612 0.593 1101\n",
" accuracy 0.605 1101\n",
" macro avg 0.561 0.560 0.561 1101\n",
"weighted avg 0.609 0.605 0.607 1101\n",
"\n"
]
}
@@ -797,27 +797,27 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 16. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.7026695416478938"
"Stopping after epoch 14. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 2.7038347354674672"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best params: {'embed_dim': 100, 'eta': 0.001, 'hidden_dim': 100}\n",
"Best score: 0.547\n",
"Best params: {'embed_dim': 75, 'eta': 0.001, 'hidden_dim': 100}\n",
"Best score: 0.546\n",
" precision recall f1-score support\n",
"\n",
" negative 0.668 0.668 0.668 428\n",
" neutral 0.291 0.218 0.249 229\n",
" positive 0.667 0.752 0.707 444\n",
" negative 0.699 0.666 0.682 428\n",
" neutral 0.299 0.240 0.266 229\n",
" positive 0.662 0.759 0.707 444\n",
"\n",
" accuracy 0.609 1101\n",
" macro avg 0.542 0.546 0.541 1101\n",
"weighted avg 0.589 0.609 0.597 1101\n",
" accuracy 0.615 1101\n",
" macro avg 0.553 0.555 0.552 1101\n",
"weighted avg 0.601 0.615 0.606 1101\n",
"\n",
"CPU times: user 6h 7min 58s, sys: 22min 13s, total: 6h 30min 12s\n",
"Wall time: 3h 35min 2s\n"
"CPU times: user 39min 55s, sys: 39.9 s, total: 40min 35s\n",
"Wall time: 39min 53s\n"
]
}
],
@@ -982,26 +982,26 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 13. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.037477616686373956"
"Stopping after epoch 13. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.021834758925251663"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best params: {'embed_dim': 100, 'eta': 0.05}\n",
"Best params: {'embed_dim': 300, 'eta': 0.05}\n",
"Best score: 0.784\n",
" precision recall f1-score support\n",
"\n",
" negative 0.779 0.814 0.796 912\n",
" positive 0.804 0.768 0.786 909\n",
" negative 0.827 0.779 0.802 912\n",
" positive 0.790 0.836 0.812 909\n",
"\n",
" accuracy 0.791 1821\n",
" macro avg 0.791 0.791 0.791 1821\n",
"weighted avg 0.791 0.791 0.791 1821\n",
" accuracy 0.807 1821\n",
" macro avg 0.808 0.807 0.807 1821\n",
"weighted avg 0.808 0.807 0.807 1821\n",
"\n",
"CPU times: user 21min 22s, sys: 1min 9s, total: 22min 31s\n",
"Wall time: 12min 1s\n"
"CPU times: user 42min 50s, sys: 28min 13s, total: 1h 11min 3s\n",
"Wall time: 17min 48s\n"
]
}
],
2 changes: 1 addition & 1 deletion test/test_colors.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import utils

__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Spring 2021"
__version__ = "CS224u, Stanford, Spring 2022"


utils.fix_random_seeds()
37 changes: 24 additions & 13 deletions torch_model_base.py
Original file line number Diff line number Diff line change
@@ -326,19 +326,10 @@ def fit(self, *args):
dataset = self.build_dataset(*args)
dataloader = self._build_dataloader(dataset, shuffle=True)

# Graph:
if not self.warm_start or not hasattr(self, "model"):
self.model = self.build_graph()
# This device move has to happen before the optimizer is built:
# https://pytorch.org/docs/master/optim.html#constructing-it
self.model.to(self.device)
self.optimizer = self.build_optimizer()
self.errors = []
self.validation_scores = []
self.no_improvement_count = 0
self.best_error = np.inf
self.best_score = -np.inf
self.best_parameters = None
# Set up parameters needed to use the model. This is a separate
# function to support using pretrained models for prediction,
# where it might not be desirable to call `fit`.
self.initialize()

# Make sure the model is where we want it:
self.model.to(self.device)
@@ -410,6 +401,26 @@ def fit(self, *args):

return self

def initialize(self):
"""
Method called by `fit` to establish core attributes. To use a
pretrained model without calling `fit`, one can use this
method.
"""
if not self.warm_start or not hasattr(self, "model"):
self.model = self.build_graph()
# This device move has to happen before the optimizer is built:
# https://pytorch.org/docs/master/optim.html#constructing-it
self.model.to(self.device)
self.optimizer = self.build_optimizer()
self.errors = []
self.validation_scores = []
self.no_improvement_count = 0
self.best_error = np.inf
self.best_score = -np.inf
self.best_parameters = None

@staticmethod
def _build_validation_split(*args, validation_fraction=0.2):
"""

0 comments on commit 56c29cc

Please sign in to comment.