-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #26 from Optum/llama_roberta_baselines
Add long roberta and llama baselines for long docs
- Loading branch information
Showing
26 changed files
with
4,978 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Code modified from: https://github.com/4AI/LS-LLaMA/tree/main |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5d7b1d18", | ||
"metadata": {}, | ||
"source": [ | ||
"### Evaluate Models\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "80b1b2c7", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Imports" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "87dc70f1", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import torch\n", | ||
"import pickle\n", | ||
"import json\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import torch.nn.functional as F\n", | ||
"from sklearn import metrics" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ed3988f7", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Evaluation Parameters" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "2a1fc80e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"threshold = 0.5 # currently we don't maximize val f1 to find the threshold... need to grab scores for all the val sets if we do this\n", | ||
"num_std = 1.96\n", | ||
"num_bootstrap = 1000\n", | ||
"line_width = 2\n", | ||
"alpha = 0.2\n", | ||
"font_size = 16\n", | ||
"legend_size = 10\n", | ||
"x_size = 10\n", | ||
"y_size = 10" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1ac2d927-76e0-48ab-8777-48bc70206d07", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Initialize Score, Model, and Color Arrays" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "acc793b2-80d0-45ac-9a3a-15dcf8fb53fb", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Define master lists of labels, scores, names, and colors\n", | ||
"all_y_trues, all_y_scores, all_model_names, all_colors = [], [], [], []" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "aef4eaf2-fff5-4f83-8ade-2367a2513aa8", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Load Fine-Tuned Torch LM Results" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "26d15d8e-81cc-4cd5-a2dc-765325cb5a55", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# # ls llama 3 8b\n", | ||
"# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n", | ||
"# ls_llama_8b_last_labels = pickle.load(f)\n", | ||
"# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_scores.pkl\", \"rb\") as f:\n", | ||
"# ls_llama_8b_last_scores = pickle.load(f)\n", | ||
"\n", | ||
"# ls_llama_8b_last_scores_transformed = torch.sigmoid(torch.tensor(ls_llama_8b_last_scores))\n", | ||
"# all_model_names.append(\"LS Llama-3 8B (Last)\")\n", | ||
"# all_y_trues.append(ls_llama_8b_last_labels)\n", | ||
"# all_y_scores.append(ls_llama_8b_last_scores_transformed)\n", | ||
"# all_colors.append('#ab20fd')\n", | ||
"\n", | ||
"# ls unllama 3 8b\n", | ||
"with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n", | ||
" ls_unllama_8b_max_labels = pickle.load(f)\n", | ||
"with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_scores.pkl\", \"rb\") as f:\n", | ||
" ls_unllama_8b_max_scores = pickle.load(f)\n", | ||
"\n", | ||
"ls_unllama_8b_max_scores_transformed = torch.sigmoid(torch.tensor(ls_unllama_8b_max_scores)).numpy()\n", | ||
"all_model_names.append(\"LS UnLlama-3 8B (Max)\")\n", | ||
"all_y_trues.append(ls_unllama_8b_max_labels)\n", | ||
"all_y_scores.append(ls_unllama_8b_max_scores_transformed)\n", | ||
"\n", | ||
"# BELT Max 5 segments\n", | ||
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_labels.pkl\", \"rb\") as f:\n", | ||
" belt_5_max_labels = pickle.load(f)\n", | ||
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_scores.pkl\", \"rb\") as f:\n", | ||
" belt_5_max_scores = pickle.load(f)\n", | ||
"\n", | ||
"all_model_names.append(\"BELT 128 step 5 seg (Max)\")\n", | ||
"all_y_trues.append(belt_5_max_labels)\n", | ||
"all_y_scores.append(belt_5_max_scores)\n", | ||
"\n", | ||
"# BELT Max 128 segments\n", | ||
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_labels.pkl\", \"rb\") as f:\n", | ||
" belt_128_max_labels = pickle.load(f)\n", | ||
"with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_scores.pkl\", \"rb\") as f:\n", | ||
" belt_128_max_scores = pickle.load(f)\n", | ||
"\n", | ||
"all_model_names.append(\"BELT 448 step 128 seg (Max)\")\n", | ||
"all_y_trues.append(belt_128_max_labels)\n", | ||
"all_y_scores.append(belt_128_max_scores)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6f21cd42", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Print Performance for all Metrics for all Models" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "2665e9ce-07c4-4196-9fe7-7f911123b8f9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def print_mean_ci_of_metric_list(metric_list, metric_name, num_std):\n", | ||
" mean_metric = np.mean(metric_list)\n", | ||
" std_metric = np.std(metric_list)\n", | ||
" metric_low = np.maximum(mean_metric - std_metric * num_std, 0)\n", | ||
" metric_high = np.minimum(mean_metric + std_metric * num_std, 1)\n", | ||
"\n", | ||
" print(\n", | ||
" f\"{metric_name}: {round(mean_metric, 3)} ([{round(metric_low, 3)} - {round(metric_high, 3)}] 95% CI)\"\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "4d6ff4a8", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"Results for LS UnLlama-3 8B (Max)\n", | ||
"\n", | ||
"Micro Average Precision: 0.277 ([0.256 - 0.299] 95% CI)\n", | ||
"Micro ROC AUC: 0.828 ([0.818 - 0.839] 95% CI)\n", | ||
"\n", | ||
"Results for BELT 128 step 5 seg (Max)\n", | ||
"\n", | ||
"Micro Average Precision: 0.707 ([0.698 - 0.716] 95% CI)\n", | ||
"Micro ROC AUC: 0.942 ([0.94 - 0.944] 95% CI)\n", | ||
"\n", | ||
"Results for BELT 448 step 128 seg (Max)\n", | ||
"\n", | ||
"Micro Average Precision: 0.804 ([0.797 - 0.812] 95% CI)\n", | ||
"Micro ROC AUC: 0.971 ([0.969 - 0.972] 95% CI)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"model2metric_df = {}\n", | ||
"for y_trues, y_scores, name in zip(\n", | ||
" all_y_trues, all_y_scores, all_model_names\n", | ||
"):\n", | ||
" \n", | ||
" micro_aps, macro_aps, micro_roc_aucs, macro_roc_aucs = [], [], [], []\n", | ||
" for i in range(num_bootstrap):\n", | ||
" \n", | ||
" # Sample N records with replacement where N is the total number of records\n", | ||
" sample_indices = np.random.choice(len(y_trues), len(y_trues))\n", | ||
" sample_labels = np.array(y_trues)[sample_indices]\n", | ||
" sample_scores = np.array(y_scores)[sample_indices]\n", | ||
" \n", | ||
" micro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n", | ||
" micro_aps.append(micro_ap)\n", | ||
"\n", | ||
" # macro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n", | ||
" # macro_aps.append(macro_ap)\n", | ||
"\n", | ||
" micro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n", | ||
" micro_roc_aucs.append(micro_roc_auc)\n", | ||
"\n", | ||
" # macro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n", | ||
" # macro_roc_aucs.append(macro_roc_auc)\n", | ||
" \n", | ||
" metric_df = pd.DataFrame({\n", | ||
" \"micro_aps\": micro_aps,\n", | ||
" \"micro_roc_aucs\": micro_roc_aucs,\n", | ||
" })\n", | ||
" model2metric_df[name] = metric_df\n", | ||
"\n", | ||
" print(f\"\\nResults for {name}\\n\")\n", | ||
" print_mean_ci_of_metric_list(micro_aps, metric_name=\"Micro Average Precision\", num_std=num_std)\n", | ||
" print_mean_ci_of_metric_list(micro_roc_aucs, metric_name=\"Micro ROC AUC\", num_std=num_std)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.10 - SDK v2", | ||
"language": "python", | ||
"name": "python310-sdkv2" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.