diff --git a/hw_openqa.ipynb b/hw_openqa.ipynb index 7bbea16..e642365 100644 --- a/hw_openqa.ipynb +++ b/hw_openqa.ipynb @@ -2,9 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "id": "G0qZfMf4Yreh" - }, + "metadata": {}, "source": [ "# Few-shot OpenQA with ColBERT retrieval" ] @@ -21,9 +19,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "NqUPAnWuYrej" - }, + "metadata": {}, "outputs": [], "source": [ "__author__ = \"Christopher Potts and Omar Khattab\"\n", @@ -32,9 +28,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "R6TRA3gNYrek" - }, + "metadata": {}, "source": [ "## Contents\n", "\n", @@ -73,9 +67,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "MXqIujoUYrek" - }, + "metadata": {}, "source": [ "## Overview\n", "\n", @@ -119,27 +111,21 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "dKB9zXRBYrel" - }, + "metadata": {}, "source": [ "## Set-up" ] }, { "cell_type": "markdown", - "metadata": { - "id": "1JwiISlVYrel" - }, + "metadata": {}, "source": [ "### Google Colab set-up" ] }, { "cell_type": "markdown", - "metadata": { - "id": "lbnxdvg7Yrem" - }, + "metadata": {}, "source": [ "We have sought to make this notebook self-contained so that it can easily be run as a Google Colab. If you are running it in Colab, make sure to select a GPU instance. The notebook will run on a CPU-only instance or CPU-only machine, but it should be much faster with GPU support.\n", "\n", @@ -149,16 +135,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "P2TXnN9HYrem", - "outputId": "c87c8b28-c38f-4fbe-f9ae-6cd6dc0d4b44" - }, + "metadata": {}, "outputs": [], "source": [ - "!pip install torch==1.10.0\n", + "!pip install torch==1.10.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html\n", "\n", "!pip install ujson\n", "\n", @@ -173,9 +153,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "hlk-FsfZYrem" - }, + "metadata": {}, "source": [ "If you are indeed on a GPU machine, then run the following to ensure complete CUDA support:" ] @@ -183,26 +161,19 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Fl5wBNCxYrem", - "outputId": "aacd2379-f6b7-428b-fec7-2dff54c2e9fa" - }, + "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "if torch.cuda.is_available():\n", - " !pip install cupy-cuda111" + " !pip uninstall cupy-cuda11x -y\n", + " !pip install cupy-cuda113" ] }, { "cell_type": "markdown", - "metadata": { - "id": "Ej8kZeh6Yren" - }, + "metadata": {}, "source": [ "If the above doesn't work, it might be because you don't have CUDA version 11.1. Run " ] @@ -210,9 +181,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "PojRXsuPYren" - }, + "metadata": {}, "outputs": [], "source": [ "import torch\n", @@ -223,18 +192,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "VfZEOAL2Yren" - }, + "metadata": {}, "source": [ "and then install the corresponding `cupy-cuda`. See [this table](https://docs.cupy.dev/en/stable/install.html#installing-cupy-from-pypi) for details on which one to install for different scenarios." ] }, { "cell_type": "markdown", - "metadata": { - "id": "sFM54iO7Yren" - }, + "metadata": {}, "source": [ "### General set-up" ] @@ -242,9 +207,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "hL9AAtTzYren" - }, + "metadata": {}, "outputs": [], "source": [ "import collections\n", @@ -262,9 +225,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "vNXANb8fYreo" - }, + "metadata": {}, "source": [ "Try to set all the seeds for reproducibility (won't extend to GPT-3):" ] @@ -272,9 +233,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "MIvsYoIpYreo" - }, + "metadata": {}, "outputs": [], "source": [ "seed = 1\n", @@ -288,9 +247,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "oKp57bQZYrep" - }, + "metadata": {}, "source": [ "The following should install the version of [Faiss](https://github.com/facebookresearch/faiss) that will cooperate with your set-up:" ] @@ -298,13 +255,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aGUj1O9wYrep", - "outputId": "ffd634ba-c30c-4d77-c1cf-31533e05a876" - }, + "metadata": {}, "outputs": [], "source": [ "import torch\n", @@ -317,18 +268,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "-mp1C-oyYreq" - }, + "metadata": {}, "source": [ "### Language model set-up" ] }, { "cell_type": "markdown", - "metadata": { - "id": "pEUY5P8GYreq" - }, + "metadata": {}, "source": [ "To use the GPT-3 API, install the OpenAI library:" ] @@ -336,14 +283,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 633 - }, - "id": "4FclKZiuYreq", - "outputId": "53c0e8ca-e5fb-4bb1-b8be-c963d7517b90" - }, + "metadata": {}, "outputs": [], "source": [ "!pip install openai" @@ -352,9 +292,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "5cbE328hYrer" - }, + "metadata": {}, "outputs": [], "source": [ "import openai\n", @@ -365,9 +303,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "TKQEIYGDYrer" - }, + "metadata": {}, "outputs": [], "source": [ "transformers.logging.set_verbosity_error()" @@ -375,9 +311,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "axUPfnNmYrer" - }, + "metadata": {}, "source": [ "### ColBERT set-up\n", "\n", @@ -389,13 +323,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "L5wP933JYrer", - "outputId": "1bd1441c-7750-4d22-9519-9f68ff3cce74" - }, + "metadata": {}, "outputs": [], "source": [ "!git clone -b cpu_inference https://github.com/stanford-futuredata/ColBERT.git" @@ -404,13 +332,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "6LWz4f-3Yres", - "outputId": "11bb2cc3-140e-46d0-d04f-a35f74c97e12" - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", @@ -419,15 +341,12 @@ "\n", "from colbert.infra import Run, RunConfig, ColBERTConfig\n", "from colbert.data import Collection\n", - "from colbert.searcher import Searcher\n", - "from utility.utils.dpr import has_answer, DPR_normalize" + "from colbert.searcher import Searcher" ] }, { "cell_type": "markdown", - "metadata": { - "id": "U5ganiuNYres" - }, + "metadata": {}, "source": [ "## Language models\n", "\n", @@ -442,9 +361,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "Fca8-RXjYres" - }, + "metadata": {}, "source": [ "### Answerhood" ] @@ -452,9 +369,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "U6KepnjTYret" - }, + "metadata": {}, "outputs": [], "source": [ "def _find_generated_answer(tokens, newline=\"\\n\" ): \n", @@ -481,9 +396,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "klW12GkAYret" - }, + "metadata": {}, "source": [ "### Eleuther models from Hugging Face" ] @@ -491,82 +404,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 209, - "referenced_widgets": [ - "19d0a74b2d134747880402c4e6f0fad0", - "fcb0f73a2e314a8e935f0802f2f12f31", - "9905ea5f04f34c5badc797bd6fa07d3a", - "4a08f6612d9d4a3e98d00926e56415b0", - "8a1a0c82837e45709424db63bc86ec86", - "fb7cce9fecc441b1baf720d2dea6dbfc", - "7086129269a04514aa49001d7e712216", - "4742b73a39d2444d882aad5b47076911", - "9dfca6ee005a4222aa369280c4381a7b", - "7583f3ce599641d08a625337921e017f", - "e98a1ee4792240b9b7285556d2f3ecbd", - "47e4b314bc6c435293cb2e5707d49ae1", - "f208df7109cd475da4f6979e387af2fc", - "c94c4ffce4cf4f9c93546bbe106bf807", - "6ceb7c5f67e24c1490ff9d209b2e23dd", - "36a30e94d4dd426db17841370f46faf1", - "9c5dd297ecb9495aa82ca68bbec66e1c", - "142b1dfda9ae4444a1d577d1c66c215a", - "c83a2859c1e941bca4849872a56baa53", - "b3990b506419477b824530e11f8a42ef", - "446124444abf4c739ad28e593f61562b", - "d5a9ed85f6424d329727e3651094d394", - "fb95bf53f6ab4ee19b1efeb17c2ea749", - "a551f75352ac45b3845c61bbd5c3bb3c", - "1a8214cc9451415483ce785b997032e6", - "e131d608e26f44f89e1b155fa8042df9", - "97c7cc3aa1cb438facac717b44861c71", - "83693e21aeae47b0a5b0ee42eb72a3b9", - "03d8affdbcfb44c08ccdf60e888d9978", - "0e16502571d94ea286c1fd32d387e5ea", - "a093e8d14d6645daae23a16ecd01e423", - "6b11d0d12a3c42f88defea028bc27545", - "a5246564505c48ee98d7644deed5a255", - "7946163e5da84225a8eccd89122b14df", - "da906492fc8044dfb04a43afa8a7bbac", - "76ab7a53d0b64bf8a670dfe31bf72422", - "c0ab47e3313549cd806f5c10d791d291", - "6b962355e3eb47678013cfd469b57237", - "f9c217ef31284d95814342521904b45f", - "133bffae565a49e79a395b08f880b3db", - "38030650fe094db6b6436fdcf458a97e", - "1b98baadc95644f5b83cc6bfc7cd8c09", - "7efc61ef150b4287841c06cc291722c7", - "38a2caefa5cd4722be6f553a1d9f7e8e", - "06363edc1de54c15a0a00e104a08d07b", - "108c1fbcc56443ba96d57133cad11645", - "eb96d7ee6af840b2af48783ddad1ddb5", - "f1726b2d78b54c4aa411b7e4eecbf4eb", - "279a7470d04e4a34b1afdc7d4761c815", - "47a7719397c249389d3ebc5753ba2acf", - "da6c76d121fa4e3ca3d85bc0dcde6212", - "40db9d36656c46238c5ed87adf8cb4ec", - "e927e8b18836481cb41c7117c0dcf3c4", - "c99ffa80225b4a09ad4e7570b6c4e672", - "6c7c0e3d3a844558b12dcdb952e76920", - "121976cff01349459b2909033f05c576", - "a8f1b2d780d74e4382e14974537b1e0b", - "7211cffdbc0a4ae1a9f894a98cc65b03", - "577af7ffedd8430da0870d4a9f773056", - "5616a0f10f344b499c9b9e2f2d4f97b9", - "cdb7e47723fe45e6845a02ad85871d44", - "5d67fd950c8e4732a32250c3decb9aef", - "e0d37a2cda934241b195138780fbe067", - "9ef856a58be54b9fa8b5c610051c25cf", - "23bb67fe91d644469d5be8b442b2e3b6", - "264be22c995f407ab22fb4605601f7c3" - ] - }, - "id": "1-FkbTaUYret", - "outputId": "f42b7a86-c3bb-4164-92bf-de57520f2bdf" - }, + "metadata": {}, "outputs": [], "source": [ "# \"gpt-neo-125M\" \"gpt-neo-1.3B\" \"gpt-neo-2.7B\" \"gpt-j-6B\"\n", @@ -589,9 +427,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "rEo9YKwNYret" - }, + "metadata": {}, "outputs": [], "source": [ "def run_eleuther(prompts, temperature=0.1, top_p=0.95, **generate_kwargs): \n", @@ -675,9 +511,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "7fIfuTEGYreu" - }, + "metadata": {}, "outputs": [], "source": [ "eleuther_ex = run_eleuther([ \n", @@ -689,9 +523,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "oiN04KynYreu" - }, + "metadata": {}, "source": [ "### GPT-3" ] @@ -699,9 +531,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "5t5HTfRkYreu" - }, + "metadata": {}, "outputs": [], "source": [ "def run_gpt3(prompts, engine=\"text-curie-001\", temperature=0.1, top_p=0.95, **gpt3_kwargs):\n", @@ -789,13 +619,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "y3jbKTAbYrev", - "outputId": "6e560ddc-3762-48ce-8ba6-f8e23c0bda37" - }, + "metadata": {}, "outputs": [], "source": [ "gpt3_ex = run_gpt3([\n", @@ -807,9 +631,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "XGK3hCs9Yrev" - }, + "metadata": {}, "source": [ "## SQuAD\n", "\n", @@ -819,115 +641,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 295, - "referenced_widgets": [ - "fedae0a67e9a4f249d9d2ab527d973e2", - "280ab82e1c5d4dd8abc4bd15e58034af", - "8c59fe3638c344748d6eed8206e8975d", - "5f72f99cd4094dc292146e35a4d67767", - "35e94aa9d357405780025da79e3691af", - "87991a91694f456f9318dc5afccec2c9", - "32e27633a7054f17af4dca4ed5e43279", - "1acc184682164c08aee48278ae399814", - "368e5fc79b9e4ff9b4fe2e512947d677", - "72570da75be34d37b32be2dae93222ec", - "91859e90f7784956bdc1d40a984d2c5b", - "bb922ea5e97a44d8b0e94d9f76773167", - "32fbedf574754ccc85c4b59d83f743f4", - "439bd70a6fda44b993051827f6f072e3", - "b704aad9cc6e474f988e8966ca179da1", - "dcc8bd4d5b7242ee8c47586cb73c7c0a", - "5c33f1097177468aa401020bac71a5be", - "cef06e938024471cabfcfdad4411b961", - "830d66f3b3e64b6f8c86ba775539672e", - "d370b83b677d454ea1c81470671443fd", - "8639b41657e04b70ac4f3f3ca517a035", - "5071e184457e42bf91a84cc5ea827c19", - "88fe4f76341e461abf81f2bf7596913b", - "5c12675e10044c6db66daf30f4d4f948", - "38cc26500d7f45909637c46e8a676cb4", - "6bc69367538f4af683b8a21b3dc59268", - "2d637c205b844d189f5e4232cb1ffda2", - "6510644f7d0545e183053878c6d6c161", - "73cdc7a035d544618e66ed586d7b9649", - "7018804c8581472298d7c8f24beda197", - "62196b4464bc48a59fa20f1f28f0fe7b", - "c28e93ddbc77422484ef2de085794312", - "b55b8143e71c4a2cbca7d399fd179797", - "8c52264dbf2f4cca86e05c112423b536", - "d14af2ecb8664d26b0792eb47791741b", - "52cd2dfb427b49feb9546ae276438cc4", - "f007de44513a41bf95f9f8b861cdaa6f", - "f0d1f5455d934205b20d0b635b746f5d", - "ac576998bfc940e7af4490bae2ad934f", - "4b007c951e474e3db19aa16ca96c16a6", - "40c60aa79e644772944b4ffdfd6f7c5a", - "9fd4ae98811341e285ccf84fc7bb73bb", - "73b8aa0fe175415a9864d017c5f88682", - "c0f59c1508c64a02bdbff82ed569b800", - "d556c1920c6f493996a7ba6b0d5fb946", - "2974fef15c3841baba7035eb6d360156", - "0ac869bb35124b39be5bdad7fc78473f", - "a7e04c08f551443481414a0dd619006e", - "b40610e2ebb74029bddea0c818532604", - "c05c539d546447a5882191f1f1e676a9", - "86a5ca66daa64c4b888c375b0d014099", - "acfc9aa4af0b4ba2a397789fdc63d4cb", - "254c594f27134d9b9e18736b7f8d0d57", - "c3a64df8553f4deea779b13a44233546", - "ad4e6f09e7b342a59235965668fc34a6", - "483121c04c514094ba89858a569144d5", - "872202218d51454faa1211c35a96dc40", - "b58f664dd2644d43946aec35857d2b11", - "41f19619d7e347569818bf786673c468", - "3fb7649992084382b327cb61976124ea", - "ce46384c19e84fd99447ede08bbaf243", - "03b85fa378584208a26436b2c65c62a8", - "c760fb9137f44be789276235d2ec552b", - "dc4900955f81453f98a10ea9b74b983c", - "84da01f23908412b9daef0a8cbcf6337", - "7efeb5f13634479391cffa9a6f9bea49", - "06c3d71451a44147892cf77cbee7704e", - "f4047dc7f8fe40e3858efacfe3fa9db3", - "36a8a57e71b54bf297e8dda3043ff966", - "78f4d0554a3748328f99d69a22eaa482", - "90cb7850489648eab6264dd813719a9c", - "2e25e5700521472e83e6888067336682", - "ab24a9d6b7a84981ac9d335ff571f7e5", - "2fe2c9cf13de4b968b4f318b221aa7dc", - "88c02342d65b4b8f9ee1cd25e4ea15ac", - "8ae88195643b40bda346711154a9c7bd", - "1625a303748840aea84cd884c8453578", - "d833a94f93074a4085fb2de3583037f5", - "ae5e15d70bf64cde9d8f2caefa2c27f2", - "2badb7d424fa46f6b88b3e7acaf6ab48", - "42c3232aaad049ff92e95957128e2f2d", - "9c27bd5813724a2c835e62854b76c2b8", - "ed33408a0cad4a46a8366263bdf81bfa", - "5969f13e33594392ad1a9245313fd91f", - "3b55afcc50d44f5bbdc3b2c6d06bbc87", - "d84e03f9117448ef970b930aa77fc968", - "467d2cd91db246ffaba3af644a47ff1d", - "cb5b0f438ff7492e88d9fe4cc9184e9d", - "3538e755160e478f8cf1262e14f2c484", - "6b096d4bd8d4497c98176ccf9fe26a13", - "e7b24e1fa266446cb6bb61ec067fa412", - "3eccf8ba2e3f4ee4b3a90a1cd89e43b1", - "6af64542f9d547afb656f14a87e6ba03", - "62639c9b27be44ee83f9b054f76f651b", - "42b97283bd4d4cdc92fed717c2513919", - "f86f33b86afb468c8758174771d676c4", - "304a8303bd0044e6aa53335b6ab30d5a", - "d958f867d5eb40379753a14982302d08", - "76b9355cce89436892b959f00b1cf702" - ] - }, - "id": "YQ_do58EYrev", - "outputId": "93464cd2-7658-40ac-f411-90cc3b7ba698" - }, + "metadata": {}, "outputs": [], "source": [ "squad = load_dataset(\"squad\")" @@ -935,9 +649,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "VbCUz66GYrev" - }, + "metadata": {}, "source": [ "The following utility just reads a SQuAD split in as a list of `SquadExample` instances:" ] @@ -945,9 +657,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "B9-0hkxgYrew" - }, + "metadata": {}, "outputs": [], "source": [ "SquadExample = namedtuple(\"SquadExample\", \"id title context question answers\")" @@ -956,9 +666,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "g2gt0dkeYrew" - }, + "metadata": {}, "outputs": [], "source": [ "def get_squad_split(squad, split=\"validation\"):\n", @@ -987,9 +695,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "NatnOLsDYrew" - }, + "metadata": {}, "outputs": [], "source": [ "squad_dev = get_squad_split(squad)" @@ -998,13 +704,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "jb-ZrSzoYrew", - "outputId": "c3d9b481-2620-434a-cd7b-43c90fc2d096" - }, + "metadata": {}, "outputs": [], "source": [ "squad_dev[0]" @@ -1048,9 +748,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "DZpXMk-0Yrew" - }, + "metadata": {}, "source": [ "## Evaluation\n", "\n", @@ -1064,9 +762,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "nHHntDSSYrew" - }, + "metadata": {}, "outputs": [], "source": [ "def normalize_answer(s: str) -> str:\n", @@ -1128,9 +824,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "lJo6C7pgYrex" - }, + "metadata": {}, "source": [ "The following is our general evaluation function. We will make extensive use of it to evaluate different systems:" ] @@ -1138,9 +832,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "bJHSxnA6Yrex" - }, + "metadata": {}, "outputs": [], "source": [ "def evaluate(examples, prompts, gens):\n", @@ -1183,9 +875,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "fj170C1PYrex" - }, + "metadata": {}, "source": [ "Here is a highly simplified example to help make the logic behind `evaluate` clearer: " ] @@ -1193,13 +883,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0bgFXLK3Yrex", - "outputId": "e941a6f7-e326-4b26-e91e-8c0ccd11a0a0" - }, + "metadata": {}, "outputs": [], "source": [ "ex = namedtuple(\"SquadExample\", \"id title context question answers\")\n", @@ -1219,18 +903,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "MHZHve9NYrex" - }, + "metadata": {}, "source": [ "The bake-off uses `macro_f1` as the primary metric." ] }, { "cell_type": "markdown", - "metadata": { - "id": "_3LQv2lbYrex" - }, + "metadata": {}, "source": [ "## Open QA with no context\n", "\n", @@ -1240,9 +920,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "swF4V0ngYrex" - }, + "metadata": {}, "outputs": [], "source": [ "def evaluate_no_context(examples, gen_func=run_eleuther, batch_size=20):\n", @@ -1259,13 +937,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aPuq5nzDYrex", - "outputId": "cea1edaf-728b-4037-abbc-b58a2d82bfbd" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1277,13 +949,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lWi1r1-oYrey", - "outputId": "9ffe19d9-ddb1-4a91-9950-06ae521a2a77" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1294,9 +960,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "csAimDMGYrey" - }, + "metadata": {}, "source": [ "## Few-shot QA\n", "\n", @@ -1345,9 +1009,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "lhsl9yvHYrey" - }, + "metadata": {}, "outputs": [], "source": [ "def build_few_shot_qa_prompt(ex, squad_train, n_context=2, joiner=\"\\n\\n\"):\n", @@ -1371,9 +1033,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "XzLghzI5Yrez" - }, + "metadata": {}, "source": [ "Here's the sort of output we get with `n_context=1`:" ] @@ -1381,13 +1041,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VEuVae4xYrez", - "outputId": "1fc52cb8-1fa7-4d84-c911-4e3f2cafd682" - }, + "metadata": {}, "outputs": [], "source": [ "print(build_few_shot_qa_prompt(dev_exs[0], squad_train, n_context=1))" @@ -1396,9 +1050,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "VUwgM625Yrez" - }, + "metadata": {}, "outputs": [], "source": [ "def evaluate_few_shot_qa(examples, squad_train, gen_func=run_eleuther, batch_size=20, n_context=2):\n", @@ -1416,13 +1068,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "UelblRFCYrez", - "outputId": "0587385d-207a-47bb-d363-5b33c3a20920" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1434,10 +1080,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "-fdPsEbmYrez", - "outputId": "be6e22d2-7655-4c98-89f9-a0fe81d204d8" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1449,9 +1092,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "WfvVdsGrYre0" - }, + "metadata": {}, "source": [ "## ColBERT\n", "\n", @@ -1471,9 +1112,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "NFYxJPpuYre0" - }, + "metadata": {}, "source": [ "### ColBERT parameters" ] @@ -1481,13 +1120,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5tnUU2UHYre0", - "outputId": "9d7431a0-7f69-4549-960f-56f8575a8e97" - }, + "metadata": {}, "outputs": [], "source": [ "if not os.path.exists(os.path.join(\"data\", \"openqa\", \"colbertv2.0.tar.gz\")):\n", @@ -1499,18 +1132,14 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "cjZ1hZJnYre0" - }, + "metadata": {}, "source": [ "If something went wrong with the above, you can just download the file https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/colbertv2.0.tar.gz, unarchive it, and move the resulting `colbertv2.0` directory into the `data/openqa` directory." ] }, { "cell_type": "markdown", - "metadata": { - "id": "QzRjL61eYre0" - }, + "metadata": {}, "source": [ "### ColBERT index" ] @@ -1518,9 +1147,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "jD0U5Outa9HU" - }, + "metadata": {}, "outputs": [], "source": [ "if not os.path.exists(os.path.join(index_home, \"cs224u.collection.2bits.tgz\")):\n", @@ -1530,9 +1157,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "61ergQcQYre0" - }, + "metadata": {}, "source": [ "If something went wrong with the above, download the file https://web.stanford.edu/class/cs224u/data/cs224u.collection.2bits.tgz, unarchive it, and move the resulting `cs224u.collection.2bits` directory into the `experiments/notebook/indexes` directory (which you will probably need to create)." ] @@ -1540,14 +1165,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 71 - }, - "id": "XtEGC6MyYre0", - "outputId": "93164bc8-b9da-4893-8b2f-d5eeca0403a3" - }, + "metadata": {}, "outputs": [], "source": [ "collection = os.path.join(index_home, \"cs224u.collection.2bits\", \"cs224u.collection.tsv\")\n", @@ -1560,9 +1178,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "BixmAYizYre0" - }, + "metadata": {}, "outputs": [], "source": [ "index_name = \"cs224u.collection.2bits\"" @@ -1578,13 +1194,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "N6wYpChzYre1", - "outputId": "1eb08be0-a66c-4414-d69b-32733d74631b" - }, + "metadata": {}, "outputs": [], "source": [ "with Run().context(RunConfig(experiment='notebook')):\n", @@ -1593,9 +1203,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "dex5KPUTYre1" - }, + "metadata": {}, "source": [ "### Search\n", "\n", @@ -1605,13 +1213,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "-v2jKaR8Yre1", - "outputId": "6ae9e5ee-8a1d-4d00-9925-1dbf47d21412" - }, + "metadata": {}, "outputs": [], "source": [ "query = \"linguistics\"\n", @@ -1628,9 +1230,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "w9n0_xwdYre1" - }, + "metadata": {}, "source": [ "### Retrieval evaluation\n", "\n", @@ -1640,9 +1240,16 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "fsNayHzdYre1" - }, + "metadata": {}, + "outputs": [], + "source": [ + "from utility.utils.dpr import has_answer, DPR_normalize" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ "def success_at_k(examples, k=20):\n", @@ -1664,9 +1271,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "k2U4v58dYre1" - }, + "metadata": {}, "source": [ "Here is Sucess@20 for the SQuAD dev set:" ] @@ -1674,13 +1279,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "J2oEmspeYre1", - "outputId": "bfff7ec2-26e4-47ef-a101-04fff7fd0621" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1695,9 +1294,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "teKobQM8Yre1" - }, + "metadata": {}, "source": [ "## Zero-shot OpenQA with ColBERT retrieval\n", "\n", @@ -1707,9 +1304,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "jbn4BHr4Yre1" - }, + "metadata": {}, "outputs": [], "source": [ "def build_zero_shot_openqa_prompt(question, passage, joiner=\"\\n\\n\"):\n", @@ -1726,9 +1321,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "Xjz7l_yQYre2" - }, + "metadata": {}, "outputs": [], "source": [ "def evaluate_zero_shot_openqa(examples, joiner=\"\\n\\n\", gen_func=run_eleuther, batch_size=20):\n", @@ -1749,13 +1342,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "21DfS0hHYre2", - "outputId": "b653e3e8-252c-47f8-8481-c190aca6c286" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1767,13 +1354,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "MNhMla71Yre2", - "outputId": "d6a6d52b-4af5-4519-cd7f-2298e859cc79" - }, + "metadata": {}, "outputs": [], "source": [ "%%time\n", @@ -1784,9 +1365,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "outblNTaYre2" - }, + "metadata": {}, "source": [ "## Homework questions\n", "\n", @@ -2002,9 +1581,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "w2mx3Z4HYre2" - }, + "metadata": {}, "source": [ "### Few-shot OpenQA [2 points]\n", "\n", @@ -2020,9 +1597,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "HUuZ5l3gYre2" - }, + "metadata": {}, "outputs": [], "source": [ "def build_few_shot_open_qa_prompt(question, passage, train_exs, joiner=\"\\n\\n\"):\n", @@ -2053,9 +1628,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "vgeNwTu4Yre2" - }, + "metadata": {}, "outputs": [], "source": [ "def test_build_few_shot_open_qa_prompt(func):\n", @@ -2086,13 +1659,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lK991-2AYre3", - "outputId": "b8134fd6-d477-41e4-9d70-cbed90f17a0c" - }, + "metadata": {}, "outputs": [], "source": [ "test_build_few_shot_open_qa_prompt(build_few_shot_open_qa_prompt)" @@ -2101,9 +1668,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "DRhoMeEGYre3" - }, + "metadata": {}, "outputs": [], "source": [ "def evaluate_few_shot_open_qa(\n", @@ -2158,38 +1723,15 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "beEyy_eOYre3", - "outputId": "22b49855-289f-43e8-f214-5d18238a00b7" - }, + "metadata": {}, "outputs": [], "source": [ "test_evaluator(evaluate_few_shot_open_qa)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "pXNI9BRNYre3", - "outputId": "2e428199-d5f7-4346-a80f-7a6934366efe" - }, - "outputs": [], - "source": [ - "\n" - ] - }, { "cell_type": "markdown", - "metadata": { - "id": "oXeQzplkYre3" - }, + "metadata": {}, "source": [ "### Answer scoring [2 points]\n", "\n", @@ -2215,9 +1757,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "m7PhfMNsYre3" - }, + "metadata": {}, "outputs": [], "source": [ "def get_passages_with_scores(question, k=5):\n", @@ -2254,9 +1794,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "-vqCNOtMYre4" - }, + "metadata": {}, "outputs": [], "source": [ "def test_get_passages_with_scores(func):\n", @@ -2281,13 +1819,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "kfsb4pyHYre4", - "outputId": "f13fc8e2-f21b-44ec-b16f-caa2855b8136" - }, + "metadata": {}, "outputs": [], "source": [ "test_get_passages_with_scores(get_passages_with_scores)" @@ -2296,9 +1828,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "SwQVdCb6Yre4" - }, + "metadata": {}, "outputs": [], "source": [ "def answer_scoring(passages, passage_probs, prompts, gen_func=run_eleuther):\n", @@ -2339,9 +1869,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "JD7d8ucgYre4" - }, + "metadata": {}, "outputs": [], "source": [ "def test_answer_scoring(func):\n", @@ -2379,13 +1907,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "QWwwDzvBYre4", - "outputId": "f591f713-3572-4f09-a550-08fd1595835f" - }, + "metadata": {}, "outputs": [], "source": [ "test_answer_scoring(answer_scoring)" @@ -2394,9 +1916,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "OUcdfU9SYre4" - }, + "metadata": {}, "outputs": [], "source": [ "def answer_scoring_demo(question):\n", @@ -2420,9 +1940,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "PmURWvoJYre4" - }, + "metadata": {}, "source": [ "### Your original system [3 points]\n", "\n", @@ -2468,9 +1986,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "D9pfDIrPYre4" - }, + "metadata": {}, "outputs": [], "source": [ "# PLEASE MAKE SURE TO INCLUDE THE FOLLOWING BETWEEN THE START AND STOP COMMENTS:\n", @@ -2495,9 +2011,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "zS3G0Ss7Yre4" - }, + "metadata": {}, "source": [ "## Bake-off [1 point]\n", "\n", @@ -2521,9 +2035,7 @@ }, { "cell_type": "markdown", - "metadata": { - "id": "zS3G0Ss7Yre4" - }, + "metadata": {}, "source": [ "If the above fails, you can just download https://web.stanford.edu/class/cs224u/data/cs224u-openqa-test-unlabeled.txt and place it in `data/openqa`.\n", "\n", @@ -2533,9 +2045,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "id": "GyttYJxoYre4" - }, + "metadata": {}, "outputs": [], "source": [ "def create_bakeoff_submission():\n", @@ -2603,11 +2113,8 @@ "provenance": [], "toc_visible": true }, - "interpreter": { - "hash": "a99ac6d2deb03d0b7ced3594556c328848678d7cea021ae1b9990e15d3ad5c49" - }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "nlu", "language": "python", "name": "python3" }, @@ -2620,8 +2127,12 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" + "pygments_lexer": "ipython3" + }, + "vscode": { + "interpreter": { + "hash": "ba8df77661866e4839e48c9c8b84db99d258acfe5e4b196f9a4f0237d7077e14" + } }, "widgets": { "application/vnd.jupyter.widget-state+json": {