Skip to content

Commit

Permalink
Merge pull request cgpotts#80 from cgpotts/spring-2021-prep
Browse files Browse the repository at this point in the history
Spring 2021 prep
  • Loading branch information
cgpotts authored Mar 24, 2021
2 parents efc8ced + e59ad24 commit c95760c
Show file tree
Hide file tree
Showing 70 changed files with 7,584 additions and 5,571 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,4 @@ nli-data/*
nlidata/*
rel_ext_data*
*_solved.ipynb
.DS_Store
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Code for [the Stanford course](http://web.stanford.edu/class/cs224u/).

Fall 2020
Spring 2021

# Instructors

Expand Down Expand Up @@ -35,9 +35,9 @@ A generic optimization class (`torch_model_base.py`) and subclasses for GloVe, A
Reference implementations for the `torch_*.py` models, designed to reveal more about how the optimization process works.


## `vsm_*` and `hw_wordsim.ipynb`
## `vsm_*` and `hw_wordrelatedness.ipynb`

A until on vector space models of meaning, covering traditional methods like PMI and LSA as well as newer methods like Autoencoders and GloVe. `vsm.py` provides a lot of the core functionality, and `torch_glove.py` and `torch_autoencoder.py` are the learned models that we cover. `vsm_03_retroffiting.ipynb` is an extension that uses `retrofitting.py`.
A until on vector space models of meaning, covering traditional methods like PMI and LSA as well as newer methods like Autoencoders and GloVe. `vsm.py` provides a lot of the core functionality, and `torch_glove.py` and `torch_autoencoder.py` are the learned models that we cover. `vsm_03_retroffiting.ipynb` is an extension that uses `retrofitting.py`, and `vsm_04_contextualreps.ipynb` explores methods for deriving static representations from contextual models.


## `sst_*` and `hw_sst.ipynb`
Expand All @@ -60,9 +60,9 @@ A unit on Natural Language Inference. `nli.py` provides core interfaces to a var
A unit on grounded natural language generation, focused on generating context-dependent color descriptions using the [English Stanford Colors in Context dataset](https://cocolab.stanford.edu/datasets/colors.html).


## `contextualreps.ipynb`
## `finetuning.ipynb`

Using pretrained parameters from [Hugging Face](https://huggingface.co) and [AllenNLP](https://allennlp.org) for featurization and fine-tuning.
Using pretrained parameters from [Hugging Face](https://huggingface.co) for featurization and fine-tuning.


## `evaluation_*.ipynb` and `projects.md`
Expand Down
2 changes: 1 addition & 1 deletion colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import matplotlib.patches as mpatch

__author__ = "Christopher Potts"
__version__ = "CS224u, Stanford, Fall 2020"
__version__ = "CS224u, Stanford, Spring 2021"


TURN_BOUNDARY = " ### "
Expand Down
115 changes: 70 additions & 45 deletions colors_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"__author__ = \"Christopher Potts\"\n",
"__version__ = \"CS224u, Stanford, Fall 2020\""
"__version__ = \"CS224u, Stanford, Spring 2021\""
]
},
{
Expand Down Expand Up @@ -257,7 +257,7 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAABLUlEQVR4nO3YMUrEUBRA0XyZSiutnC24EjvXajcrcQtOpZW2315kVMgQ5nJOmxTvweURMuacC5RcbT0ArE3U5IiaHFGTI2pydqcejjEu/tfInHP85b3nu/eL33VZluXp7fbXfR9uPhO7vnxc/7irS02OqMkRNTknv6m/u398Pdccqzke9luPwMZcanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9Tk7P7z8vGwP9ccsBqXmhxRkyNqcsacc+sZYFUuNTmiJkfU5IiaHFGTI2pyvgBwhhdAIEFGnQAAAABJRU5ErkJggg==\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAABLUlEQVR4nO3YMUrEUBRA0XyZSiutnC24EjvXajcrcQtOpZW2315kVMgQ5nJOmxTvweURMuacC5RcbT0ArE3U5IiaHFGTI2pydqcejjEu/tfInHP85b3nu/eL33VZluXp7fbXfR9uPhO7vnxc/7irS02OqMkRNTknv6m/u398Pdccqzke9luPwMZcanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9Tk7P7z8vGwP9ccsBqXmhxRkyNqcsacc+sZYFUuNTmiJkfU5IiaHFGTI2pyvgBwhhdAIEFGnQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 216x72 with 3 Axes>"
]
Expand Down Expand Up @@ -295,7 +295,7 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAABFUlEQVR4nO3YsW1CMRRAUX+UCipShRXYhCqzpsomrJBUSRVaswAiFEiIq3Nau3hPunLhZc45oGT16AHg3kRNjqjJETU5oibn5drhfnN6+q+R4996ueXe7vD99LuOMcbX59u/+368/iZ2ff/ZXtzVS02OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IianGXO+egZ4K681OSImhxRkyNqckRNjqjJOQNHYRKDRd/3AwAAAABJRU5ErkJggg==\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAABFUlEQVR4nO3YsW1CMRRAUX+UCipShRXYhCqzpsomrJBUSRVaswAiFEiIq3Nau3hPunLhZc45oGT16AHg3kRNjqjJETU5oibn5drhfnN6+q+R4996ueXe7vD99LuOMcbX59u/+368/iZ2ff/ZXtzVS02OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IianGXO+egZ4K681OSImhxRkyNqckRNjqjJOQNHYRKDRd/3AwAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 216x72 with 3 Axes>"
]
Expand Down Expand Up @@ -338,7 +338,7 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAABLElEQVR4nO3YsU3DUBRAUX+UCiqoyApMQsesdJmEFUgFFbSfBVBwYcnK5ZzWLt6Trp4sjznnAiU3ew8AWxM1OaImR9TkiJqcw6WHrw+fV/9r5OXjfqx57+nu++p3XZZlefu6/XPfMUZi1znnr7u61OSImhxRk3Pxm5r/4fH5fe8RVjmfjqvec6nJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRk3PYewD2dz4d9x5hUy41OaImR9TkjDnn3jPAplxqckRNjqjJETU5oiZH1OT8AK1HF0DPcEkgAAAAAElFTkSuQmCC\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAABLElEQVR4nO3YsU3DUBRAUX+UCiqoyApMQsesdJmEFUgFFbSfBVBwYcnK5ZzWLt6Trp4sjznnAiU3ew8AWxM1OaImR9TkiJqcw6WHrw+fV/9r5OXjfqx57+nu++p3XZZlefu6/XPfMUZi1znnr7u61OSImhxRk3Pxm5r/4fH5fe8RVjmfjqvec6nJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRk3PYewD2dz4d9x5hUy41OaImR9TkjDnn3jPAplxqckRNjqjJETU5oiZH1OT8AK1HF0DPcEkgAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 216x72 with 3 Axes>"
]
Expand Down Expand Up @@ -580,7 +580,7 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAABLUlEQVR4nO3YwUnEUBRA0XyZbrQE3Qp24SytaJZOF4JbLUHr+TYgYxaBMNdztsniPbg8Qsacc4GSm70HgK2JmhxRkyNqckRNzuHSw7vvx6v/NfJ1+z7WvPf59nD1uy7Lstw/ffy57xgjseuc89ddXWpyRE2OqMm5+E3N//Dy/Lr3CKuczsdV77nU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMk57D0A+zudj3uPsCmXmhxRkyNqcsacc+8ZYFMuNTmiJkfU5IiaHFGTI2pyfgAdJBcf7IJsUgAAAABJRU5ErkJggg==\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAABLUlEQVR4nO3YwUnEUBRA0XyZbrQE3Qp24SytaJZOF4JbLUHr+TYgYxaBMNdztsniPbg8Qsacc4GSm70HgK2JmhxRkyNqckRNzuHSw7vvx6v/NfJ1+z7WvPf59nD1uy7Lstw/ffy57xgjseuc89ddXWpyRE2OqMm5+E3N//Dy/Lr3CKuczsdV77nU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMkRNTmiJkfU5IiaHFGTI2pyRE2OqMk57D0A+zudj3uPsCmXmhxRkyNqcsacc+8ZYFMuNTmiJkfU5IiaHFGTI2pyfgAdJBcf7IJsUgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 216x72 with 3 Axes>"
]
Expand Down Expand Up @@ -619,7 +619,7 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAABKklEQVR4nO3YQUrDUBRA0XzpbuyO7FTXUHAsuAad1h3Z9Xw3IDWDQOj1nGkyeA8uj5Ax51yg5GHvAWBroiZH1OSImhxRk3O49fDt8+Xuf428Pn+MNe9dTl93v+uyLMvp8vTnvmOMxK5zzl93danJETU5oibn5jc1/8P79+PeI6xyPl5XvedSkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oibnsPcA7O98vO49wqZcanJETY6oyRlzzr1ngE251OSImhxRkyNqckRNjqjJ+QHLEhcAkintbgAAAABJRU5ErkJggg==\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAABKklEQVR4nO3YQUrDUBRA0XzpbuyO7FTXUHAsuAad1h3Z9Xw3IDWDQOj1nGkyeA8uj5Ax51yg5GHvAWBroiZH1OSImhxRk3O49fDt8+Xuf428Pn+MNe9dTl93v+uyLMvp8vTnvmOMxK5zzl93danJETU5oibn5jc1/8P79+PeI6xyPl5XvedSkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oiZH1OSImhxRkyNqckRNjqjJETU5oibnsPcA7O98vO49wqZcanJETY6oyRlzzr1ngE251OSImhxRkyNqckRNjqjJ+QHLEhcAkintbgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 216x72 with 3 Axes>"
]
Expand Down Expand Up @@ -658,7 +658,7 @@
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAABK0lEQVR4nO3YwUnEUBRA0XyZCnRlH4rDtGC9tiAO2ocrbeHbgIxZBMJcz9kmi/fg8ggZc84FSm72HgC2JmpyRE2OqMkRNTmHSw+/7z6v/tfI7df9WPPey+PH1e+6LMvy/P7w575jjMSuc85fd3WpyRE1OaIm5+I3Nf/D29Pr3iOscjyfVr3nUpMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaIm57D3AOzveD7tPcKmXGpyRE2OqMkZc869Z4BNudTkiJocUZMjanJETY6oyfkBPhUWwkgMDc4AAAAASUVORK5CYII=\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALUAAABECAYAAADHnXQVAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAABK0lEQVR4nO3YwUnEUBRA0XyZCnRlH4rDtGC9tiAO2ocrbeHbgIxZBMJcz9kmi/fg8ggZc84FSm72HgC2JmpyRE2OqMkRNTmHSw+/7z6v/tfI7df9WPPey+PH1e+6LMvy/P7w575jjMSuc85fd3WpyRE1OaIm5+I3Nf/D29Pr3iOscjyfVr3nUpMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaImR9TkiJocUZMjanJETY6oyRE1OaIm57D3AOzveD7tPcKmXGpyRE2OqMkZc869Z4BNudTkiJocUZMjanJETY6oyfkBPhUWwkgMDc4AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 216x72 with 3 Axes>"
]
Expand Down Expand Up @@ -945,7 +945,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Finished epoch 200 of 200; error is 0.26593604683876043"
"Finished epoch 200 of 200; error is 0.26593607664108276"
]
}
],
Expand Down Expand Up @@ -1106,7 +1106,9 @@
}
],
"source": [
"toy_mod.corpus_bleu(toy_color_seqs_test, toy_word_seqs_test)"
"bleu_score, predicted_texts = toy_mod.corpus_bleu(toy_color_seqs_test, toy_word_seqs_test)\n",
"\n",
"bleu_score"
]
},
{
Expand Down Expand Up @@ -1216,8 +1218,8 @@
"output_type": "stream",
"text": [
"{'<s>': 1.0, '</s>': 0.0, 'A': 0.0, 'B': 0.0, '$UNK': 0.0}\n",
"{'<s>': 0.00018379976, '</s>': 0.00022975517, 'A': 0.9946944, 'B': 0.004481194, '$UNK': 0.00041091075}\n",
"{'<s>': 0.0010102483, '</s>': 0.023374218, 'A': 0.0016727167, 'B': 0.9730926, '$UNK': 0.0008501807}\n",
"{'<s>': 0.00018379976, '</s>': 0.00022975517, 'A': 0.9946944, 'B': 0.004481194, '$UNK': 0.00041091096}\n",
"{'<s>': 0.0010102493, '</s>': 0.02337423, 'A': 0.0016727175, 'B': 0.9730926, '$UNK': 0.00085018104}\n",
"{'<s>': 0.0046478347, '</s>': 0.9801214, 'A': 0.01115099, 'B': 0.0027307996, '$UNK': 0.001349019}\n"
]
}
Expand Down Expand Up @@ -1250,7 +1252,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Finished epoch 200 of 200; error is 0.42917731404304504"
"/Applications/anaconda3/envs/nlu/lib/python3.8/site-packages/numpy/core/_asarray.py:83: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray\n",
" return array(a, dtype, copy=False, order=order)\n",
"Finished epoch 200 of 200; error is 0.45002618432044983"
]
},
{
Expand Down Expand Up @@ -1427,15 +1431,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 60.62587070465088"
"Stopping after epoch 17. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 57.85624027252197"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4min 17s, sys: 8.47 s, total: 4min 26s\n",
"Wall time: 1min 6s\n"
"CPU times: user 1min 50s, sys: 5.29 s, total: 1min 55s\n",
"Wall time: 56.9 s\n"
]
}
],
Expand All @@ -1454,28 +1458,49 @@
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"dev_mod_eval = dev_mod.evaluate(dev_cols_test, dev_word_seqs_test)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cgpotts/Documents/teaching/2019-2020/xcs224u/cs224u/torch_color_describer.py:678: RuntimeWarning: divide by zero encountered in power\n",
" perp = [np.prod(s)**(-1/len(s)) for s in scores]\n"
]
},
"data": {
"text/plain": [
"0.367117765620501"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dev_mod_eval['listener_accuracy']"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'listener_accuracy': 0.32450331125827814, 'corpus_bleu': 0.05031672905269219}"
"0.05830693924560899"
]
},
"execution_count": 52,
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dev_mod.evaluate(dev_cols_test, dev_word_seqs_test)"
"dev_mod_eval['corpus_bleu']"
]
},
{
Expand Down Expand Up @@ -1517,7 +1542,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1542,7 +1567,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1567,7 +1592,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1601,7 +1626,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1613,14 +1638,14 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Finished epoch 1000 of 1000; error is 0.10807410627603531"
"Finished epoch 1000 of 1000; error is 0.12768782675266266"
]
}
],
Expand All @@ -1630,7 +1655,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 60,
"metadata": {},
"outputs": [
{
Expand All @@ -1639,7 +1664,7 @@
"1.0"
]
},
"execution_count": 58,
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1673,7 +1698,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1711,7 +1736,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1747,7 +1772,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1759,14 +1784,14 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Finished epoch 1000 of 1000; error is 0.13370060920715332"
"Finished epoch 1000 of 1000; error is 0.1362161487340927"
]
}
],
Expand All @@ -1776,22 +1801,22 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'listener_accuracy': 1.0, 'corpus_bleu': 1.0}"
"1.0"
]
},
"execution_count": 63,
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mod_deep.evaluate(toy_color_seqs_test, toy_word_seqs_test)"
"mod_deep.listener_accuracy(toy_color_seqs_test, toy_word_seqs_test)"
]
}
],
Expand All @@ -1811,9 +1836,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
Loading

0 comments on commit c95760c

Please sign in to comment.