Skip to content

Commit

Permalink
Add long roberta and llama baselines for long docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jstremme authored Aug 2, 2024
1 parent 3048f77 commit 8637345
Show file tree
Hide file tree
Showing 26 changed files with 4,978 additions and 0 deletions.
1 change: 1 addition & 0 deletions llama/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Code modified from: https://github.com/4AI/LS-LLaMA/tree/main
262 changes: 262 additions & 0 deletions llama/evaluate_models.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "5d7b1d18",
"metadata": {},
"source": [
"### Evaluate Models\n"
]
},
{
"cell_type": "markdown",
"id": "80b1b2c7",
"metadata": {},
"source": [
"##### Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "87dc70f1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"import pickle\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import torch.nn.functional as F\n",
"from sklearn import metrics"
]
},
{
"cell_type": "markdown",
"id": "ed3988f7",
"metadata": {},
"source": [
"##### Evaluation Parameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2a1fc80e",
"metadata": {},
"outputs": [],
"source": [
"threshold = 0.5 # currently we don't maximize val f1 to find the threshold... need to grab scores for all the val sets if we do this\n",
"num_std = 1.96\n",
"num_bootstrap = 1000\n",
"line_width = 2\n",
"alpha = 0.2\n",
"font_size = 16\n",
"legend_size = 10\n",
"x_size = 10\n",
"y_size = 10"
]
},
{
"cell_type": "markdown",
"id": "1ac2d927-76e0-48ab-8777-48bc70206d07",
"metadata": {},
"source": [
"##### Initialize Score, Model, and Color Arrays"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "acc793b2-80d0-45ac-9a3a-15dcf8fb53fb",
"metadata": {},
"outputs": [],
"source": [
"# Define master lists of labels, scores, names, and colors\n",
"all_y_trues, all_y_scores, all_model_names, all_colors = [], [], [], []"
]
},
{
"cell_type": "markdown",
"id": "aef4eaf2-fff5-4f83-8ade-2367a2513aa8",
"metadata": {},
"source": [
"##### Load Fine-Tuned Torch LM Results"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "26d15d8e-81cc-4cd5-a2dc-765325cb5a55",
"metadata": {},
"outputs": [],
"source": [
"# # ls llama 3 8b\n",
"# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n",
"# ls_llama_8b_last_labels = pickle.load(f)\n",
"# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_scores.pkl\", \"rb\") as f:\n",
"# ls_llama_8b_last_scores = pickle.load(f)\n",
"\n",
"# ls_llama_8b_last_scores_transformed = torch.sigmoid(torch.tensor(ls_llama_8b_last_scores))\n",
"# all_model_names.append(\"LS Llama-3 8B (Last)\")\n",
"# all_y_trues.append(ls_llama_8b_last_labels)\n",
"# all_y_scores.append(ls_llama_8b_last_scores_transformed)\n",
"# all_colors.append('#ab20fd')\n",
"\n",
"# ls unllama 3 8b\n",
"with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n",
" ls_unllama_8b_max_labels = pickle.load(f)\n",
"with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_scores.pkl\", \"rb\") as f:\n",
" ls_unllama_8b_max_scores = pickle.load(f)\n",
"\n",
"ls_unllama_8b_max_scores_transformed = torch.sigmoid(torch.tensor(ls_unllama_8b_max_scores)).numpy()\n",
"all_model_names.append(\"LS UnLlama-3 8B (Max)\")\n",
"all_y_trues.append(ls_unllama_8b_max_labels)\n",
"all_y_scores.append(ls_unllama_8b_max_scores_transformed)\n",
"\n",
"# BELT Max 5 segments\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_labels.pkl\", \"rb\") as f:\n",
" belt_5_max_labels = pickle.load(f)\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_scores.pkl\", \"rb\") as f:\n",
" belt_5_max_scores = pickle.load(f)\n",
"\n",
"all_model_names.append(\"BELT 128 step 5 seg (Max)\")\n",
"all_y_trues.append(belt_5_max_labels)\n",
"all_y_scores.append(belt_5_max_scores)\n",
"\n",
"# BELT Max 128 segments\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_labels.pkl\", \"rb\") as f:\n",
" belt_128_max_labels = pickle.load(f)\n",
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_scores.pkl\", \"rb\") as f:\n",
" belt_128_max_scores = pickle.load(f)\n",
"\n",
"all_model_names.append(\"BELT 448 step 128 seg (Max)\")\n",
"all_y_trues.append(belt_128_max_labels)\n",
"all_y_scores.append(belt_128_max_scores)"
]
},
{
"cell_type": "markdown",
"id": "6f21cd42",
"metadata": {},
"source": [
"##### Print Performance for all Metrics for all Models"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "2665e9ce-07c4-4196-9fe7-7f911123b8f9",
"metadata": {},
"outputs": [],
"source": [
"def print_mean_ci_of_metric_list(metric_list, metric_name, num_std):\n",
" mean_metric = np.mean(metric_list)\n",
" std_metric = np.std(metric_list)\n",
" metric_low = np.maximum(mean_metric - std_metric * num_std, 0)\n",
" metric_high = np.minimum(mean_metric + std_metric * num_std, 1)\n",
"\n",
" print(\n",
" f\"{metric_name}: {round(mean_metric, 3)} ([{round(metric_low, 3)} - {round(metric_high, 3)}] 95% CI)\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4d6ff4a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Results for LS UnLlama-3 8B (Max)\n",
"\n",
"Micro Average Precision: 0.277 ([0.256 - 0.299] 95% CI)\n",
"Micro ROC AUC: 0.828 ([0.818 - 0.839] 95% CI)\n",
"\n",
"Results for BELT 128 step 5 seg (Max)\n",
"\n",
"Micro Average Precision: 0.707 ([0.698 - 0.716] 95% CI)\n",
"Micro ROC AUC: 0.942 ([0.94 - 0.944] 95% CI)\n",
"\n",
"Results for BELT 448 step 128 seg (Max)\n",
"\n",
"Micro Average Precision: 0.804 ([0.797 - 0.812] 95% CI)\n",
"Micro ROC AUC: 0.971 ([0.969 - 0.972] 95% CI)\n"
]
}
],
"source": [
"model2metric_df = {}\n",
"for y_trues, y_scores, name in zip(\n",
" all_y_trues, all_y_scores, all_model_names\n",
"):\n",
" \n",
" micro_aps, macro_aps, micro_roc_aucs, macro_roc_aucs = [], [], [], []\n",
" for i in range(num_bootstrap):\n",
" \n",
" # Sample N records with replacement where N is the total number of records\n",
" sample_indices = np.random.choice(len(y_trues), len(y_trues))\n",
" sample_labels = np.array(y_trues)[sample_indices]\n",
" sample_scores = np.array(y_scores)[sample_indices]\n",
" \n",
" micro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n",
" micro_aps.append(micro_ap)\n",
"\n",
" # macro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n",
" # macro_aps.append(macro_ap)\n",
"\n",
" micro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n",
" micro_roc_aucs.append(micro_roc_auc)\n",
"\n",
" # macro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n",
" # macro_roc_aucs.append(macro_roc_auc)\n",
" \n",
" metric_df = pd.DataFrame({\n",
" \"micro_aps\": micro_aps,\n",
" \"micro_roc_aucs\": micro_roc_aucs,\n",
" })\n",
" model2metric_df[name] = metric_df\n",
"\n",
" print(f\"\\nResults for {name}\\n\")\n",
" print_mean_ci_of_metric_list(micro_aps, metric_name=\"Micro Average Precision\", num_std=num_std)\n",
" print_mean_ci_of_metric_list(micro_roc_aucs, metric_name=\"Micro ROC AUC\", num_std=num_std)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10 - SDK v2",
"language": "python",
"name": "python310-sdkv2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit 8637345

Please sign in to comment.