diff --git a/llama/README.md b/llama/README.md new file mode 100644 index 0000000..c64ffd0 --- /dev/null +++ b/llama/README.md @@ -0,0 +1 @@ +Code modified from: https://github.com/4AI/LS-LLaMA/tree/main \ No newline at end of file diff --git a/llama/evaluate_models.ipynb b/llama/evaluate_models.ipynb new file mode 100644 index 0000000..ef15cf5 --- /dev/null +++ b/llama/evaluate_models.ipynb @@ -0,0 +1,262 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5d7b1d18", + "metadata": {}, + "source": [ + "### Evaluate Models\n" + ] + }, + { + "cell_type": "markdown", + "id": "80b1b2c7", + "metadata": {}, + "source": [ + "##### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "87dc70f1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "import pickle\n", + "import json\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import torch.nn.functional as F\n", + "from sklearn import metrics" + ] + }, + { + "cell_type": "markdown", + "id": "ed3988f7", + "metadata": {}, + "source": [ + "##### Evaluation Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2a1fc80e", + "metadata": {}, + "outputs": [], + "source": [ + "threshold = 0.5 # currently we don't maximize val f1 to find the threshold... 
need to grab scores for all the val sets if we do this\n", + "num_std = 1.96\n", + "num_bootstrap = 1000\n", + "line_width = 2\n", + "alpha = 0.2\n", + "font_size = 16\n", + "legend_size = 10\n", + "x_size = 10\n", + "y_size = 10" + ] + }, + { + "cell_type": "markdown", + "id": "1ac2d927-76e0-48ab-8777-48bc70206d07", + "metadata": {}, + "source": [ + "##### Initialize Score, Model, and Color Arrays" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "acc793b2-80d0-45ac-9a3a-15dcf8fb53fb", + "metadata": {}, + "outputs": [], + "source": [ + "# Define master lists of labels, scores, names, and colors\n", + "all_y_trues, all_y_scores, all_model_names, all_colors = [], [], [], []" + ] + }, + { + "cell_type": "markdown", + "id": "aef4eaf2-fff5-4f83-8ade-2367a2513aa8", + "metadata": {}, + "source": [ + "##### Load Fine-Tuned Torch LM Results" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "26d15d8e-81cc-4cd5-a2dc-765325cb5a55", + "metadata": {}, + "outputs": [], + "source": [ + "# # ls llama 3 8b\n", + "# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n", + "# ls_llama_8b_last_labels = pickle.load(f)\n", + "# with open(\"ls-Meta-Llama-3-8B-msp-v2-mdace-20_scores.pkl\", \"rb\") as f:\n", + "# ls_llama_8b_last_scores = pickle.load(f)\n", + "\n", + "# ls_llama_8b_last_scores_transformed = torch.sigmoid(torch.tensor(ls_llama_8b_last_scores))\n", + "# all_model_names.append(\"LS Llama-3 8B (Last)\")\n", + "# all_y_trues.append(ls_llama_8b_last_labels)\n", + "# all_y_scores.append(ls_llama_8b_last_scores_transformed)\n", + "# all_colors.append('#ab20fd')\n", + "\n", + "# ls unllama 3 8b\n", + "with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_labels.pkl\", \"rb\") as f:\n", + " ls_unllama_8b_max_labels = pickle.load(f)\n", + "with open(\"ls-unllama-Meta-Llama-3-8B-msp-v2-mdace-20_raw_scores.pkl\", \"rb\") as f:\n", + " ls_unllama_8b_max_scores = pickle.load(f)\n", + "\n", + 
"ls_unllama_8b_max_scores_transformed = torch.sigmoid(torch.tensor(ls_unllama_8b_max_scores)).numpy()\n", + "all_model_names.append(\"LS UnLlama-3 8B (Max)\")\n", + "all_y_trues.append(ls_unllama_8b_max_labels)\n", + "all_y_scores.append(ls_unllama_8b_max_scores_transformed)\n", + "\n", + "# BELT Max 5 segments\n", + "with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_labels.pkl\", \"rb\") as f:\n", + " belt_5_max_labels = pickle.load(f)\n", + "with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_128_max_5_scores.pkl\", \"rb\") as f:\n", + " belt_5_max_scores = pickle.load(f)\n", + "\n", + "all_model_names.append(\"BELT 128 step 5 seg (Max)\")\n", + "all_y_trues.append(belt_5_max_labels)\n", + "all_y_scores.append(belt_5_max_scores)\n", + "\n", + "# BELT Max 128 segments\n", + "with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_labels.pkl\", \"rb\") as f:\n", + " belt_128_max_labels = pickle.load(f)\n", + "with open(\"./BELT-BASELINE/bioclinicalroberta_belt_mdace20_510_step_448_max_128_scores.pkl\", \"rb\") as f:\n", + " belt_128_max_scores = pickle.load(f)\n", + "\n", + "all_model_names.append(\"BELT 448 step 128 seg (Max)\")\n", + "all_y_trues.append(belt_128_max_labels)\n", + "all_y_scores.append(belt_128_max_scores)" + ] + }, + { + "cell_type": "markdown", + "id": "6f21cd42", + "metadata": {}, + "source": [ + "##### Print Performance for all Metrics for all Models" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2665e9ce-07c4-4196-9fe7-7f911123b8f9", + "metadata": {}, + "outputs": [], + "source": [ + "def print_mean_ci_of_metric_list(metric_list, metric_name, num_std):\n", + " mean_metric = np.mean(metric_list)\n", + " std_metric = np.std(metric_list)\n", + " metric_low = np.maximum(mean_metric - std_metric * num_std, 0)\n", + " metric_high = np.minimum(mean_metric + std_metric * num_std, 1)\n", + "\n", + " print(\n", + " f\"{metric_name}: {round(mean_metric, 3)} 
([{round(metric_low, 3)} - {round(metric_high, 3)}] 95% CI)\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4d6ff4a8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Results for LS UnLlama-3 8B (Max)\n", + "\n", + "Micro Average Precision: 0.277 ([0.256 - 0.299] 95% CI)\n", + "Micro ROC AUC: 0.828 ([0.818 - 0.839] 95% CI)\n", + "\n", + "Results for BELT 128 step 5 seg (Max)\n", + "\n", + "Micro Average Precision: 0.707 ([0.698 - 0.716] 95% CI)\n", + "Micro ROC AUC: 0.942 ([0.94 - 0.944] 95% CI)\n", + "\n", + "Results for BELT 448 step 128 seg (Max)\n", + "\n", + "Micro Average Precision: 0.804 ([0.797 - 0.812] 95% CI)\n", + "Micro ROC AUC: 0.971 ([0.969 - 0.972] 95% CI)\n" + ] + } + ], + "source": [ + "model2metric_df = {}\n", + "for y_trues, y_scores, name in zip(\n", + " all_y_trues, all_y_scores, all_model_names\n", + "):\n", + " \n", + " micro_aps, macro_aps, micro_roc_aucs, macro_roc_aucs = [], [], [], []\n", + " for i in range(num_bootstrap):\n", + " \n", + " # Sample N records with replacement where N is the total number of records\n", + " sample_indices = np.random.choice(len(y_trues), len(y_trues))\n", + " sample_labels = np.array(y_trues)[sample_indices]\n", + " sample_scores = np.array(y_scores)[sample_indices]\n", + " \n", + " micro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n", + " micro_aps.append(micro_ap)\n", + "\n", + " # macro_ap = metrics.average_precision_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n", + " # macro_aps.append(macro_ap)\n", + "\n", + " micro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='micro')\n", + " micro_roc_aucs.append(micro_roc_auc)\n", + "\n", + " # macro_roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores, average='macro')\n", + " # macro_roc_aucs.append(macro_roc_auc)\n", + " \n", + " 
metric_df = pd.DataFrame({\n", + " \"micro_aps\": micro_aps,\n", + " \"micro_roc_aucs\": micro_roc_aucs,\n", + " })\n", + " model2metric_df[name] = metric_df\n", + "\n", + " print(f\"\\nResults for {name}\\n\")\n", + " print_mean_ci_of_metric_list(micro_aps, metric_name=\"Micro Average Precision\", num_std=num_std)\n", + " print_mean_ci_of_metric_list(micro_roc_aucs, metric_name=\"Micro ROC AUC\", num_std=num_std)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10 - SDK v2", + "language": "python", + "name": "python310-sdkv2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/llama/ls_llama_seq_clf.py b/llama/ls_llama_seq_clf.py new file mode 100644 index 0000000..ad126bc --- /dev/null +++ b/llama/ls_llama_seq_clf.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- + +import os +import sys +import yaml +import time +import torch +import pickle +import logging +import transformers +from datasets import DatasetDict, Dataset +from typing import List, Any, Dict +from datasets import load_dataset, load_from_disk +from transformers.data import DataCollatorWithPadding +from transformers import TrainingArguments, Trainer, EarlyStoppingCallback, AutoTokenizer +from peft import get_peft_model, LoraConfig, TaskType +import evaluate +import numpy as np +import pandas as pd +from scipy.special import expit +from sklearn.metrics import average_precision_score + +from utils import check_empty_count_gpus, create_current_run, create_log_dir +from modeling_llama_local import LlamaForSequenceClassification +from modeling_unllama import UnmaskingLlamaForSequenceClassification + +os.environ["HF_EVALUATE_OFFLINE"] = "1" +os.environ["HF_DATASETS_OFFLINE"] = "1" + +# Load Run Parameters +with 
open("params.yml", "r") as stream: + PARAMS = yaml.safe_load(stream) + +batch_size = PARAMS["batch_size"] +gradient_accumulation_steps = PARAMS["gradient_accumulation_steps"] +learning_rate = PARAMS["learning_rate"] +lora_r = PARAMS["lora_r"] +lora_a = PARAMS["lora_a"] +max_length = PARAMS["max_length"] +warmup_steps = PARAMS["warmup_steps"] +eval_steps = PARAMS["eval_steps"] +save_steps = PARAMS["save_steps"] +logging_steps = PARAMS["logging_steps"] +early_stopping_patience = PARAMS["early_stopping_patience"] +pooling_strategy = PARAMS["pooling_strategy"] +dataset_path = PARAMS["dataset_path"] +train = PARAMS["train"] +resume_training = PARAMS["resume_training"] +resume_checkpoint = PARAMS["resume_checkpoint"] +test_checkpoint = PARAMS["test_checkpoint"] +model_id = PARAMS["model_id"] +output_path = PARAMS["output_path"] +model_name = PARAMS["model_name"] +unllama = PARAMS["unllama"] +id2label = PARAMS["id2label"] + +label2id = {v: k for k, v in id2label.items()} +ds = load_from_disk(dataset_path) + +# # This is to avoid using a map function which seems to be unreliable when used +# # in combination with the preprocess_function. I should understand this better, +# # but I'm just using Pandas for now to ensure we properly transform the label column. 
def compute_metrics(eval_pred):
    """Score one evaluation pass with micro-averaged average precision.

    ``eval_pred`` is unpacked into ``(predictions, labels)``. The predictions
    are squashed with the logistic sigmoid (``expit``) before scoring, i.e.
    they are treated as raw logits here.

    Returns:
        A one-entry dict keyed by ``"micro_avg_pr_auc"``.
    """
    logits, targets = eval_pred
    return {
        "micro_avg_pr_auc": average_precision_score(
            targets, expit(logits), average='micro'
        )
    }
tokenized_ds['train'].to_pandas() +with pd.option_context('display.max_rows', None, 'display.max_columns', None): + print(df.head()) + +# Train +if train: + + if unllama: + model = UnmaskingLlamaForSequenceClassification.from_pretrained(model_id, num_labels=len(label2id), id2label=id2label, label2id=label2id).bfloat16() + model.set_pooling(pooling_strategy) + else: + model = LlamaForSequenceClassification.from_pretrained(model_id, num_labels=len(label2id)).bfloat16() + + # set the pad token of the model's configuration + # https://stackoverflow.com/questions/68084302/assertionerror-cannot-handle-batch-sizes-1-if-no-padding-token-is-defined + model.config.pad_token_id = model.config.eos_token_id + + # # Set problem type explicitly + # model.config.problem_type = "binary_classification" + + peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=lora_r, lora_alpha=lora_a, lora_dropout=0.1) + model = get_peft_model(model, peft_config) + + training_args = TrainingArguments( + output_dir=current_run_dir, + max_steps=1000000000, + learning_rate=learning_rate, + per_device_train_batch_size=batch_size, + per_device_eval_batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + metric_for_best_model="eval_loss", + evaluation_strategy="steps", + save_strategy="steps", + eval_steps=eval_steps, + save_steps=save_steps, + warmup_steps=warmup_steps, + logging_steps=logging_steps, + logging_dir=logging_dir, + lr_scheduler_type='linear', + weight_decay=0.01, + adam_beta1=0.9, + adam_beta2=0.999, + adam_epsilon=0.00000001, + optim="adamw_torch", + load_best_model_at_end=True, + push_to_hub=False, + bf16=True, + bf16_full_eval=True, + gradient_checkpointing=True, + label_names='label' + ) + + # Define early stopping callback + early_stopping = EarlyStoppingCallback(early_stopping_patience=early_stopping_patience) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized_ds["train"], + 
eval_dataset=tokenized_ds["val"], + callbacks=[early_stopping], + tokenizer=tokenizer, + compute_metrics=compute_metrics, + ) + + # data_collator=data_collator, + + # Start training timer + start_time = time.time() + + # Start from model provided above and new training parameters defined above + if not resume_training: + trainer.train() + + # Resume using model and training parameters defined in checkpoint + else: + trainer.train(resume_checkpoint) + + # Log training time + end_time = time.time() + execution_time_hours = round((end_time - start_time) / 3600.0, 2) + logger.info(f"Training took {execution_time_hours} hours.") + + # Save best model + trainer.model.save_pretrained( + os.path.join(current_run_dir, f'{model_name}_model') + ) + tokenizer.save_pretrained( + os.path.join(current_run_dir, f'{model_name}_tokenizer') + ) + +# Test +else: + + if unllama: + model = UnmaskingLlamaForSequenceClassification.from_pretrained(test_checkpoint, num_labels=len(label2id), id2label=id2label, label2id=label2id).bfloat16() + model.set_pooling(pooling_strategy) + else: + model = LlamaForSequenceClassification.from_pretrained(test_checkpoint, num_labels=len(label2id)).bfloat16() + + peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=True, r=lora_r, lora_alpha=lora_a, lora_dropout=0.1) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + trainer = Trainer(model=model) + +# Predict on test data +# Apply softmax or sigmoid to these outputs! 
+output = trainer.predict(tokenized_ds["test"]) +labels = output.label_ids +probs = torch.tensor(output.predictions) + +with open(f"./{model_name}_raw_scores.pkl", "wb") as f: + pickle.dump(probs.cpu().detach().numpy(), f) +with open(f"./{model_name}_raw_labels.pkl", "wb") as f: + pickle.dump(labels, f) \ No newline at end of file diff --git a/llama/modeling_llama_local.py b/llama/modeling_llama_local.py new file mode 100644 index 0000000..7f6d801 --- /dev/null +++ b/llama/modeling_llama_local.py @@ -0,0 +1,1259 @@ +# Copied by Joel in May of 2024 +# transformers==4.35.1 + +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch LLaMA model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from transformers.models.llama.configuration_llama import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
+if is_torch_fx_available(): + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + warnings.warn( + "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask" + ) + return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) + + +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + warnings.warn( + "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. 
class LlamaRotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embeddings.

    The tables are built once for ``max_position_embeddings`` positions and
    rebuilt on demand when ``forward`` is asked for a longer sequence.
    Subclasses customise the tables by overriding ``_set_cos_sin_cache``.

    Args:
        dim: Width of the rotary tables (last dimension of the cos/sin cache).
        max_position_embeddings: Number of positions cached at construction.
        base: Base of the geometric progression of inverse frequencies.
        device: Device on which the inverse frequencies are created.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin buffers for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return the first ``seq_len`` rows of the cos and sin tables.

        Args:
            x: Tensor laid out as [bs, num_attention_heads, seq_len, head_size];
                only its dtype, device and (when ``seq_len`` is omitted) its
                second-to-last dimension are used.
            seq_len: Number of positions required. When ``None`` it is taken
                from ``x.shape[-2]``.

        Returns:
            ``(cos, sin)``, each sliced to ``seq_len`` rows and cast to
            ``x.dtype``.
        """
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            # The signature advertises `seq_len` as optional; without this
            # fallback the comparison below raises TypeError on None.
            seq_len = x.shape[-2]

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), 
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is cut at ``x.shape[-1] // 2``; the result is the
    negated back part followed by the front part.
    """
    cut = x.shape[-1] // 2
    front = x[..., :cut]
    back = x[..., cut:]
    return torch.cat((torch.neg(back), front), dim=-1)
class LlamaMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers. ``hidden_size``,
    ``intermediate_size``, ``hidden_act`` and ``pretraining_tp`` are read from
    ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x``.

        When ``config.pretraining_tp > 1`` the weight matrices are split into
        that many shards and the projections are computed shard by shard; the
        partial down-projections are summed. Otherwise the three linear layers
        are applied directly. The sharded path splits the activations along
        ``dim=2``, so it needs an input with at least three dimensions.
        """
        tp = self.config.pretraining_tp
        if tp > 1:
            # `shard_size` (not `slice`) so the builtin is not shadowed.
            shard_size = self.intermediate_size // tp
            gate_shards = self.gate_proj.weight.split(shard_size, dim=0)
            up_shards = self.up_proj.weight.split(shard_size, dim=0)
            down_shards = self.down_proj.weight.split(shard_size, dim=1)

            gate_proj = torch.cat([F.linear(x, gate_shards[i]) for i in range(tp)], dim=-1)
            up_proj = torch.cat([F.linear(x, up_shards[i]) for i in range(tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(shard_size, dim=2)
            down_proj = sum(F.linear(intermediate_states[i], down_shards[i]) for i in range(tp))
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj
class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries use ``config.num_attention_heads`` heads while keys and values use
    ``config.num_key_value_heads`` heads; the key/value heads are repeated
    ``num_key_value_groups`` times in ``forward`` so they line up with the
    query heads. Rotary position embeddings are applied to queries and keys.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        # The integer division above must be exact, otherwise the head split is invalid.
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        ``None`` selects the plain rotary embedding; otherwise the ``"type"``
        entry picks the ``"linear"`` or ``"dynamic"`` variant and ``"factor"``
        is forwarded as its scaling factor.

        Raises:
            ValueError: If the scaling type is neither ``"linear"`` nor
                ``"dynamic"``.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as (bsz, num_heads, seq_len, head_dim), contiguous."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention over ``hidden_states``.

        Args:
            hidden_states: Input unpacked as (bsz, q_len, hidden).
            attention_mask: Optional additive mask; it must have size
                (bsz, 1, q_len, kv_seq_len) and is added to the raw scores.
            position_ids: Positions used to index the rotary cos/sin tables.
            past_key_value: Optional cached ``(keys, values)`` that are
                prepended along the sequence dimension.
            output_attentions: When false, the returned attention weights are
                ``None``.
            use_cache: When true, the concatenated ``(keys, values)`` are
                returned for reuse.
            **kwargs: Only ``padding_mask`` is inspected, and only to emit a
                deprecation warning; its value is not used here.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``.

        Raises:
            ValueError: If the attention scores, the mask, or the attention
                output do not have the expected size.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Sharded path: split each projection weight into `pretraining_tp`
            # row blocks, project with each block, and concatenate the results.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # Split the feature dimension into heads: (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key/value length = new tokens plus any cached tokens.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # Rotary embedding is applied to the new keys before they join the cache.
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        # The cache is captured before the key/value heads are repeated below.
        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores.
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back: (bsz, q_len, hidden_size).
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            # Sharded output projection: column blocks of the weight applied to
            # matching slices of the input, partial results summed.
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + 
value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. 
class LlamaDecoderLayer(nn.Module):
    # One decoder block: RMS-normed self-attention and RMS-normed MLP, each added
    # back onto its own input (residual connections).

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        # The flash-attention implementation is selected by a private flag on the config.
        if getattr(config, "_flash_attn_2_enabled", False):
            attention_cls = LlamaFlashAttention2
        else:
            attention_cls = LlamaAttention
        self.self_attn = attention_cls(config=config)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Run the block on `hidden_states` and return a tuple of results.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                When truthy, the attention weights returned by the attention module are appended to the result.
            use_cache (`bool`, *optional*):
                When truthy, the key/value state returned by the attention module is appended to the result.
            **kwargs: forwarded to the attention module; a `padding_mask` entry triggers a deprecation warning.

        Returns:
            `(hidden_states,)`, followed by the attention weights if `output_attentions`, followed by the present
            key/value state if `use_cache`.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # --- attention sub-block: normalise, attend, add the input back ---
        attn_out, attn_probs, new_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        after_attention = hidden_states + attn_out

        # --- feed-forward sub-block: normalise, MLP, add the input back ---
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        block_output = after_attention + mlp_out

        # Optional extras are appended in a fixed order: attentions first, then cache.
        result = (block_output,)
        if output_attentions:
            result = result + (attn_probs,)
        if use_cache:
            result = result + (new_key_value,)
        return result
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    # Shared base for the Llama model classes: class-level settings read by the
    # `PreTrainedModel` machinery plus the weight-initialisation rule.
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialise `module` in place if it is an `nn.Linear` or `nn.Embedding`.

        Weights are drawn from a normal distribution with mean 0 and standard
        deviation `config.initializer_range`. Linear biases (when present) and the
        embedding row at `padding_idx` (when set) are zeroed. Other module types
        are left untouched.
        """
        sigma = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=sigma)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=sigma)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
+ + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. 
@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding table; `padding_idx` is passed as the third positional argument.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Final normalisation applied after the last decoder layer.
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Any flag left as None falls back to the corresponding value on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        # Exactly one of the two must be given; batch size and sequence length come
        # from the first two dimensions of whichever one is present.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Cached length is read from dim 2 of the first tensor of the first layer's cache entry.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        # Default positions continue on from the cached length: [past, past + seq_length).
        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            # The mask is dropped (set to None) when it contains no zero entry.
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # The first hidden state is the token embeddings as-is; nothing positional is
        # added here (position_ids are handed to each decoder layer instead).
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        # Accumulators are tuples when the matching output was requested, otherwise None.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            # Records the INPUT of each layer; the final output is appended after the loop.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Arguments are passed positionally here, so their order must match
                # the decoder layer's forward signature.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            # The cache entry sits at index 2 when attentions are also returned, else at index 1.
            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        # Tuple form omits every entry that is None, so positions shift with the flags.
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if 
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Bias-free linear head mapping each hidden state to `num_labels` logits.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped `LlamaModel`."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped `LlamaModel` with `value`."""
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

            When `config.problem_type` is unset it is inferred on the first call that receives labels and stored
            on the config: one label -> regression; several labels with integer (`long`/`int`) dtype ->
            single-label classification; otherwise multi-label classification (BCE-with-logits on float labels).

        Returns:
            A `SequenceClassifierOutputWithPast`, or, when `return_dict` is false, a tuple
            `(loss?, pooled_logits, *transformer_outputs[1:])` where `loss` is present only if `labels` was given.

        Raises:
            ValueError: if no `pad_token_id` is configured and the batch size is not 1.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Per-token logits; one position per row is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is None or input_ids is None:
            # Without a pad id (or without token ids to inspect) use the last position of every row.
            sequence_lengths = -1
        else:
            # Index of the first pad token per row, minus one. A row with no pad token
            # yields argmax == 0 and therefore -1; the modulo turns that into the explicit
            # last index instead of relying on negative indexing (same element in eager
            # mode, but well-defined for tracing/export).
            first_pad = torch.eq(input_ids, pad_token_id).long().argmax(-1)
            sequence_lengths = ((first_pad - 1) % input_ids.shape[-1]).to(logits.device)

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Inferred once and cached on the config for subsequent calls.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                # Labels are cast to float before the BCE loss so integer multi-hot labels are accepted.
                loss = loss_fct(pooled_logits, labels.float())

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Build an additive causal attention mask.

    Entry ``[i, j]`` is ``0`` when query position ``i`` may attend to key position ``j`` (``j <= i``) and
    ``torch.finfo(dtype).min`` otherwise. Columns for cached (past) keys are prepended as zeros, i.e. always visible.

    Returns a tensor of shape ``(bsz, 1, tgt_len, tgt_len + past_key_values_length)``.
    """
    batch, query_len = input_ids_shape
    blocked_value = torch.finfo(dtype).min

    # Keep the fill value only strictly above the diagonal (future positions); everything else becomes 0.
    square = torch.full((query_len, query_len), blocked_value, device=device)
    causal = torch.triu(square, diagonal=1).to(dtype)

    if past_key_values_length > 0:
        visible_past = torch.zeros(query_len, past_key_values_length, dtype=dtype, device=device)
        causal = torch.cat([visible_past, causal], dim=-1)

    key_len = query_len + past_key_values_length
    return causal[None, None, :, :].expand(batch, 1, query_len, key_len)


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expand a padding mask from `[bsz, seq_len]` to an additive mask of shape `[bsz, 1, tgt_seq_len, src_seq_len]`.

    Positions where `mask` is 1 become 0 (attend); every other position becomes ``torch.finfo(dtype).min``.
    """
    batch, key_len = mask.size()
    query_len = key_len if tgt_len is None else tgt_len

    keep = mask[:, None, None, :].expand(batch, 1, query_len, key_len).to(dtype)
    blocked = 1.0 - keep

    # Any non-zero entry marks a position that must not be attended to.
    return blocked.masked_fill(blocked != 0, torch.finfo(dtype).min)
class UnmaskingLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`].

    Unlike the stock LLaMA model, `forward` does not apply a causal mask: every decoder layer receives an all-zero
    additive mask, so each position can attend to every other position.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Build the combined causal + padding mask. Kept for reference; `forward` does not call it."""
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Run the decoder stack with full (non-causal) attention.

        The `attention_mask` argument is accepted for interface compatibility but its contents are not used: the mask
        handed to the layers is all zeros.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Causal mask removed: an all-zero additive mask lets every query attend to every key.
        # It is created in the activation dtype (adding it must not promote the attention scores) and spans the
        # cached keys as well, so its last dimension matches the key length when `past_key_values` is supplied.
        # NOTE(review): padding positions are not masked out either, so pad tokens take part in attention.
        # Confirm that this is intended before relying on batches with heavy padding.
        attention_mask = torch.zeros(
            (batch_size, 1, seq_length, seq_length_with_past),
            dtype=inputs_embeds.dtype,
            device=inputs_embeds.device,
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
@add_start_docstrings(
    """
    The LLaMa Model transformer (run without a causal mask) with a sequence classification head on top (linear layer).

    [`UnmaskingLlamaForSequenceClassification`] applies the linear head to every token and then pools the per-token
    logits over the sequence dimension. The pooling strategy is selected with `set_pooling` and is one of `'mean'`
    (default), `'max'` or `'last'`.

    With `'last'` pooling the position of the last token is needed. If a `pad_token_id` is defined in the
    configuration and `input_ids` are passed, the position just before the first padding token of each row is used.
    If no `pad_token_id` is defined, or `inputs_embeds` are passed instead of `input_ids`, the last position of each
    row is used.
    """,
    LLAMA_START_DOCSTRING,
)
class UnmaskingLlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = UnmaskingLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Pooling over the sequence dimension: 'mean', 'max' or 'last' (see `set_pooling`).
        self.pooling = 'mean'
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_pooling(self, pooling):
        """Select how per-token logits are reduced to one vector per sequence ('mean', 'max' or 'last')."""
        self.pooling = pooling

    def _pool_logits(self, logits, input_ids):
        """
        Reduce per-token logits of shape `(batch, seq_len, num_labels)` to `(batch, num_labels)`.

        Raises:
            NotImplementedError: if `self.pooling` is not one of 'mean', 'max', 'last'.
            ValueError: for 'last' pooling with a batch larger than 1 and no `pad_token_id` configured.
        """
        # NOTE(review): 'max' and 'mean' reduce over every position, padding included. Confirm this is intended.
        if self.pooling == 'max':
            pooled_logits, _ = torch.max(logits, dim=1)
            return pooled_logits
        if self.pooling == 'mean':
            return torch.mean(logits, dim=1)
        if self.pooling != 'last':
            raise NotImplementedError(f"Unsupported pooling strategy: {self.pooling!r}")

        # Only 'last' pooling needs a last-token index, so the padding-token requirement is checked only here.
        batch_size = logits.shape[0]
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None or input_ids is None:
            sequence_lengths = -1
        else:
            # Index of the first pad token minus one; a row with no pad token yields -1, i.e. the last position.
            sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                logits.device
            )
        return logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` and the labels are integer typed a classification loss is computed
            (Cross-Entropy); otherwise a multi-label loss is computed (binary cross-entropy with logits).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)
        pooled_logits = self._pool_logits(logits, input_ids)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                # Infer the problem type once from the head size and label dtype, then cache it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                # BCEWithLogitsLoss needs floating-point targets.
                loss = loss_fct(pooled_logits, labels.float())
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
@add_start_docstrings(
    """
    The LLaMa Model transformer with a token classification head on top (dropout followed by a linear layer applied
    to the hidden state of every token).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        # Use `config.classifier_dropout` when the config defines it; otherwise keep the previous fixed rate of 0.1.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy). Indices should be in `[0, ...,
            config.num_labels - 1]`; positions labelled `-100` are skipped by `CrossEntropyLoss`'s default
            `ignore_index`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Labels may live on another device than the logits (e.g. when the model is sharded).
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten to (batch * seq_len, num_labels) vs (batch * seq_len,) for a per-token loss.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    The LLaMa Model transformer, run without a causal mask ([`UnmaskingLlamaModel`]), with a token classification
    head on top (dropout followed by a linear layer applied to the hidden state of every token).
    """,
    LLAMA_START_DOCSTRING,
)
class UnmaskingLlamaForTokenClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = UnmaskingLlamaModel(config)
        # Use `config.classifier_dropout` when the config defines it; otherwise keep the previous fixed rate of 0.1.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy). Indices should be in `[0, ...,
            config.num_labels - 1]`; positions labelled `-100` are skipped by `CrossEntropyLoss`'s default
            `ignore_index`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Labels may live on another device than the logits (e.g. when the model is sharded).
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            # Flatten to (batch * seq_len, num_labels) vs (batch * seq_len,) for a per-token loss.
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
GPT-2) do + +id2label: + 0: "Negative" + 1: "Positive" diff --git a/llama/plot_trainer_state.py b/llama/plot_trainer_state.py new file mode 100644 index 0000000..e77c916 --- /dev/null +++ b/llama/plot_trainer_state.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +Given the path to a trainer state from Hugging Face transformers, plot the learning curves and learning rate. +Plots are saved to the checkpoint directory as well as the working directory for easy access. +This script assumes the number of logging steps, save steps, and eval steps is equal. +It also assumes that validation loss is used as the monitoring metric for checkpointing. +""" + +# Open imports +import os +import json +import yaml +import matplotlib.pyplot as plt + +# Set checkpoint path to evaluate +CHECKPOINT_PATH = ( + "../ls-Meta-Llama-3-8B-msp-v2-mdace-20/run_2/checkpoint-8000" +) + + + +def main(): + + # Get working dir + working_dir = os.getcwd() + + # Load trainer state + trainer_state_file = os.path.join(CHECKPOINT_PATH, "trainer_state.json") + with open(trainer_state_file) as f: + states = json.load(f) + + # Get global step and best val loss + global_step = states["global_step"] + best_val_loss = states["best_metric"] + + # Get log history + steps = [] + train_loss = [] + val_loss = [] + lrs = [] + for i, state in enumerate(states["log_history"]): + + # Every other entry in the log history + # Contains the training info + if i % 2 == 0: + steps.append(state["step"]) + train_loss.append(state["loss"]) + lrs.append(state["learning_rate"]) + else: + val_loss.append(state["eval_loss"]) + + # Plot train and eval loss + fig, ax = plt.subplots(figsize=(8, 8)) + ax.plot(steps, train_loss, label="Train Loss") + ax.plot(steps, val_loss, label="Val Loss") + ax.legend(loc="best", prop={"size": 10}) + plt.title(f"Learning Curves at Checkpoint {global_step}", fontsize=10) + plt.xlabel("Step", fontsize=10) + plt.ylabel("Loss", fontsize=10) + plt.tight_layout() + 
plt.savefig(os.path.join(CHECKPOINT_PATH, "learning_curves.png")) + plt.savefig(os.path.join(working_dir, "learning_curves.png")) + print(f"Best validation loss = {best_val_loss}.") + + # Plot learning rate + fig, ax = plt.subplots(figsize=(8, 8)) + ax.plot(steps, lrs, label="Learning Rate") + ax.legend(loc="best", prop={"size": 10}) + plt.title(f"Learning Rates at Checkpoint {global_step}", fontsize=10) + plt.xlabel("Step", fontsize=10) + plt.ylabel("Learning Rate", fontsize=10) + plt.tight_layout() + plt.savefig(os.path.join(CHECKPOINT_PATH, "learning_rates.png")) + plt.savefig(os.path.join(working_dir, "learning_rates.png")) + print(f"Current learning rate = {lrs[-1]}.") + + +if __name__ == "__main__": + + main() diff --git a/llama/requirements.txt b/llama/requirements.txt new file mode 100644 index 0000000..c4902e2 --- /dev/null +++ b/llama/requirements.txt @@ -0,0 +1,7 @@ +datasets +evaluate +numpy +peft +transformers==4.35.1 +sentencepiece==0.1.99 +torch==1.13.1 \ No newline at end of file diff --git a/llama/unllama_token_clf.py b/llama/unllama_token_clf.py new file mode 100644 index 0000000..ed7c3d0 --- /dev/null +++ b/llama/unllama_token_clf.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- + +import sys +import json +import numpy as np +import evaluate +from datasets import load_dataset, Dataset, DatasetDict +from transformers import AutoTokenizer +from transformers import DataCollatorForTokenClassification +from transformers import TrainingArguments, Trainer +from peft import get_peft_model, LoraConfig, TaskType + +from modeling_llama import UnmaskingLlamaForTokenClassification + + +def load_ontonotesv5(): + ret = {} + for split_name in ['train', 'dev', 'test']: + data = [] + with open(f'./data/ontonotesv5/{split_name}.jsonl', 'r') as reader: + for line in reader: + data.append(json.loads(line)) + ret[split_name] = Dataset.from_list(data) + return DatasetDict(ret) + + +if len(sys.argv) != 3: + print('usage python %.py task model_size') + sys.exit() + +task, 
def tokenize_and_align_labels(examples):
    """
    Tokenize pre-split words and align the word-level NER tags with the resulting sub-word tokens.

    Only the first token of each word keeps the word's tag; special tokens and the remaining tokens of a word are
    labelled -100. The aligned labels are stored under "labels" in the returned tokenizer output.
    """
    encoded = tokenizer(
        examples["tokens"],
        is_split_into_words=True,
        padding='longest',
        max_length=max_length,
        truncation=True,
    )

    aligned = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        word_ids = encoded.word_ids(batch_index=row)  # token position -> index of the word it came from
        # Pair every token with the word index of the token before it (None before the first token).
        preceding = [None] + word_ids[:-1]
        aligned.append(
            [
                word_tags[current] if current is not None and current != previous else -100
                for previous, current in zip(preceding, word_ids)
            ]
        )

    encoded["labels"] = aligned
    return encoded
def compute_metrics(p):
    """
    Compute overall precision, recall, F1 and accuracy with seqeval for a `(predictions, labels)` pair.

    The predictions are arg-maxed over their last axis, and every position whose label is -100 is dropped before the
    class indices are mapped to label names.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept = [(pred, gold) for pred, gold in zip(predicted_row, gold_row) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept])
        true_labels.append([label_list[gold] for _, gold in kept])

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    metric_names = ("precision", "recall", "f1", "accuracy")
    return {name: results[f"overall_{name}"] for name in metric_names}
def _next_run_index(entry_names):
    """Return one more than the highest N among names of the exact form ``run_N`` (0 when there is none)."""
    indices = []
    for name in entry_names:
        prefix, _, suffix = name.partition("_")
        if prefix == "run" and suffix.isdigit():
            indices.append(int(suffix))
    return max(indices, default=-1) + 1


def create_current_run(save_path, params, logger=None):
    """
    Create a directory for the current run, save the
    current pipeline parameters, and return the
    path to the current run directory.

    Runs are numbered ``run_0``, ``run_1``, ...; entries of
    save_path that are not named ``run_<int>`` are ignored.
    save_path is created when it does not exist yet.
    """

    # Create current run dir
    os.makedirs(save_path, exist_ok=True)
    run_index = _next_run_index(os.listdir(save_path))
    current_run_dir = os.path.join(save_path, "run_" + str(run_index) + "/")
    os.makedirs(current_run_dir)

    if logger:
        logger.info(f"Created current run dir: {current_run_dir}.")

    # Save run params in current run dir for reference
    with open(os.path.join(current_run_dir, "params.yml"), "w") as stream:
        yaml.dump(params, stream, default_flow_style=False)

    if logger:
        logger.info("Saved run parameter to current run dir.")

    return current_run_dir


def create_log_dir(current_run_dir, logger=None):
    """
    Make sure the ``logs/`` directory exists inside
    current_run_dir.
    """

    logging_dir = os.path.join(current_run_dir, "logs/")

    # exist_ok avoids the race between checking for the directory and creating it
    os.makedirs(logging_dir, exist_ok=True)

    if logger:
        logger.info(f"Created logging directory: {logging_dir}.")


def check_empty_count_gpus(logger=None):
    """
    Check that GPU is available, empty the cache,
    and count the number of available devices.
    Returns the device count.
    """

    # Check that a GPU is available. Raised explicitly so the check is not
    # stripped under `python -O`; the exception type is the one `assert` raised.
    if not torch.cuda.is_available():
        raise AssertionError("No GPU found. Please run on a GPU.")

    # Empty GPU cache
    torch.cuda.empty_cache()

    # Count available devices
    device_count = torch.cuda.device_count()

    if logger:
        logger.info(f"Found {device_count} GPU(s)!")

    return device_count
def convert_1d_binary_labels_to_2d(labels):
    """
    Convert 1D binary labels to a 2D representation.

    A 1D sequence of integer 0/1 labels becomes a float array of
    shape (n, 2) in which label 0 maps to [1, 0] and label 1 maps
    to [0, 1]. Anything else is treated as already 2D: a numpy
    array is returned unchanged and other sequences are converted
    with np.array. Empty input yields an array of shape (0, 2),
    unless it is already a multi-dimensional numpy array.

    Raises AssertionError when integer labels are not 1D or hold
    values other than 0 and 1.
    """

    # Nothing to inspect: labels[0] would fail on empty input
    if len(labels) == 0:
        if isinstance(labels, np.ndarray) and labels.ndim > 1:
            return labels
        return np.zeros(shape=(0, 2))

    # Convert a 1D, binary label array to 2D
    if isinstance(labels[0], (int, np.integer)):

        flat = np.asarray(labels)

        # Raised explicitly (not via `assert`) so the checks survive `python -O`;
        # the exception type is unchanged for existing callers.
        if flat.ndim != 1:
            raise AssertionError("Expected labels to be 1D.")
        if not np.isin(flat, (0, 1)).all():
            raise AssertionError("Expected only 1s and 0s in labels.")

        # One-hot encode: row i gets a 1 in column labels[i]
        new_labels = np.zeros(shape=(flat.shape[0], 2))
        new_labels[np.arange(flat.shape[0]), flat.astype(int)] = 1

        return new_labels

    # Return 2D array
    if isinstance(labels, (np.ndarray, np.generic)):
        return labels
    return np.array(labels)
+ ) + + return model, y + + +def transform_target_to_multi_class_indices(y): + """ + Given a 2d numpy array of one hot encoded + targets, return an array of the indices + representing the encoded label for each sample + as is required for sklearn multi-class classification. + """ + + return np.argmax(y, axis=1) diff --git a/long_roberta/README.md b/long_roberta/README.md new file mode 100644 index 0000000..924295f --- /dev/null +++ b/long_roberta/README.md @@ -0,0 +1,27 @@ +# torch_long_bert + +Code to fine-tune and evaluate long versions of BERT and BERT-like LMs from Hugging Face Transformers + +### Contents + +- [About](#about) +- [Environment](#environment) +- [Data Prep](#data-prep) +- [Train and Evaluate](#train-and-evaluate) + +### About + +This repository contains code to fine-tune and evaluate long versions of BERT and BERT-like LMs from Hugging Face Transformers using base PyTorch. The code in this directory has been modified from [this repository](https://github.com/mim-solutions/roberta_for_longer_texts) +and was originally written by [MichalBrzozowski91](https://github.com/MichalBrzozowski91) to implement [this suggestion](https://github.com/google-research/bert/issues/27#issuecomment-435265194) from [jacobdevlin-google](https://github.com/jacobdevlin-google). The core idea is to fine-tune a base BERT model by getting the representations from multiple concatenated windows of text with some overlap and applying sigmoid over each window to generate predictions. The final predictions are then taken as either the average or max value of the sigmoid output of all windows in a sample. + +### Environment + +To build the Python 3.10 environment required to run this code, create a Python 3.10 virtual environment with [Anaconda](https://www.anaconda.com/products/individual) and install the dependencies in `../requirements.txt`. + +### Data Prep + +This code takes as input a HuggingFace dataset with text and label columns. 
class BERTSequenceClassificationHead(nn.Module):
    """
    Classification head: a single linear projection followed by a sigmoid.

    :param params: optional dict providing ``linear_dim`` (input width) and
        ``num_labels`` (output width). When omitted, the dict is loaded from
        ``params.yml`` in the current working directory, as before.
    """

    def __init__(self, params=None):

        super().__init__()

        # Fall back to the on-disk configuration only when none is injected
        if params is None:
            with open("params.yml", "r") as stream:
                params = yaml.safe_load(stream)
        self.params = params

        self.out_proj = nn.Linear(params['linear_dim'], params['num_labels'])
        self.sigmoid = nn.Sigmoid()

    def forward(self, cls_token_hidden_state):
        """Map a hidden-state tensor to per-label sigmoid scores."""

        return self.sigmoid(self.out_proj(cls_token_hidden_state))


class BERTSequenceClassificationArch(nn.Module):
    """
    Encoder plus classification head.

    :param bert: callable encoder invoked as ``bert(input_ids, attention_mask)``
    :param params: optional configuration dict forwarded to the head
    """

    def __init__(self, bert, params=None):

        super().__init__()
        self.bert = bert
        self.classification_head = BERTSequenceClassificationHead(params)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs, take the first-token vector and classify it."""

        pooled = bert_vectorize(self.bert, input_ids, attention_mask)
        return self.classification_head(pooled)


def bert_vectorize(bert, input_ids, attention_mask):
    """
    Run the encoder and return the hidden state at sequence position 0
    of its first output (the token equivalent to [CLS]).
    """

    outputs = bert(input_ids, attention_mask)
    sequence_output = outputs[0]

    # Keep every batch row and every feature, but only position 0
    return sequence_output[:, 0, :]
class Model():
    """
    Abstract class for models.

    Subclasses are expected to set ``preprocessor``, ``dataset_class``,
    ``nn``, ``optimizer`` and ``scheduler`` and to implement
    ``evaluate_single_batch``.
    """

    def __init__(self):

        with open("params.yml", "r") as stream:
            params = yaml.safe_load(stream)

        self.params = params
        self.preprocessor = None
        self.dataset_class = None
        self.collate_fn = None

    def evaluate_single_batch(self, batch, model, device):
        """Return ``(preds, labels)`` for one batch; implemented by subclasses."""

        raise NotImplementedError("This is implemented for subclasses only")

    def create_dataset(self, X_preprocessed, y):
        """Wrap preprocessed inputs and labels in ``self.dataset_class``."""

        return self.dataset_class(X_preprocessed, y)

    def train_and_evaluate(self, X_train, X_val, X_test, y_train, y_val, y_test, epochs, early_stopping_epochs, logger):
        """
        Preprocess the three splits, build their dataloaders and run
        ``train_and_evaluate_preprocessed``; returns its result.
        """

        # Compute number of samples
        number_of_train_samples = len(X_train)
        number_of_val_samples = len(X_val)
        number_of_test_samples = len(X_test)

        # Text preprocessing
        X_train_preprocessed = self.preprocessor.preprocess(X_train)
        X_val_preprocessed = self.preprocessor.preprocess(X_val)
        X_test_preprocessed = self.preprocessor.preprocess(X_test)

        # Creating datasets
        train_dataset = self.create_dataset(X_train_preprocessed, y_train)
        val_dataset = self.create_dataset(X_val_preprocessed, y_val)
        test_dataset = self.create_dataset(X_test_preprocessed, y_test)

        # Creating dataloaders. Only the training split is shuffled; the
        # validation split previously went through the training (random)
        # loader by mistake.
        batch_size = self.params['batch_size']
        train_dataloader = create_train_dataloader(
            train_dataset, batch_size, self.collate_fn)
        val_dataloader = create_val_dataloader(
            val_dataset, batch_size, self.collate_fn)
        test_dataloader = create_test_dataloader(
            test_dataset, batch_size, self.collate_fn)

        # Training and evaluating
        result = self.train_and_evaluate_preprocessed(
            number_of_train_samples,
            train_dataloader,
            number_of_val_samples,
            val_dataloader,
            number_of_test_samples,
            test_dataloader,
            epochs,
            early_stopping_epochs,
            logger
        )

        return result

    def train_and_evaluate_preprocessed(
        self,
        number_of_train_samples,
        train_dataloader,
        number_of_val_samples,
        val_dataloader,
        number_of_test_samples,
        test_dataloader,
        epochs,
        early_stopping_epochs,
        logger
    ):
        """
        Train for up to ``epochs`` epochs with early stopping on the
        validation loss.

        Returns a dict with per-epoch ``train_loss``, ``val_loss``,
        ``test_preds`` and ``test_labels`` lists.
        """

        result = {
            'train_loss': [],
            'val_loss': [],
            'test_preds': [],
            'test_labels': []
        }

        for epoch in range(epochs):

            # Run train epoch
            avg_loss, avg_lr = self.train_single_epoch(number_of_train_samples, train_dataloader)
            logger.info(f'Epoch: {epoch}, Train Loss: {avg_loss:.10f}, Avg LR: {avg_lr:.10f}')
            result['train_loss'].append(avg_loss)

            # Evaluate
            avg_loss, _, _ = self.evaluate_single_epoch(number_of_val_samples, val_dataloader)
            logger.info(f'Epoch: {epoch}, Val Loss: {avg_loss:.10f}')
            result['val_loss'].append(avg_loss)

            # Predict on test set and save (we should really only do this at the end but need to save the model somehow first)
            preds, labels = self.predict(number_of_test_samples, test_dataloader, with_labels=True)
            result['test_preds'].append(preds)
            result['test_labels'].append(labels)

            # Compute best epoch
            best_epoch = np.argmin(result['val_loss'])
            best_val_loss = np.min(result["val_loss"])

            # Slice starts AT the best epoch, so its length is
            # (epochs since best) + 1
            epochs_since_best = result['val_loss'][best_epoch:]

            # Early stop if too many epochs have passed since the best epoch (we should also checkpoint the model here)
            if len(epochs_since_best) > early_stopping_epochs:
                logger.info(f"Stopping at epoch {epoch}. Best val loss of {best_val_loss:.10f} occurred at epoch {best_epoch}.")
                return result

        return result

    def predict(self, number_of_test_samples, test_dataloader, with_labels=False):
        """Return predictions, or ``(preds, labels)`` when ``with_labels`` is set."""

        # Predict on test data loader
        _, preds, labels = self.evaluate_single_epoch(number_of_test_samples, test_dataloader)

        # Return labels if specified
        if with_labels:
            return preds, labels
        return preds

    def train_single_epoch(self, number_of_train_samples, train_dataloader):
        """
        Run one training epoch with gradient accumulation.

        Returns ``(avg_loss, avg_lr)``: the summed loss divided by the
        number of training samples, and the learning rate averaged over
        the batches seen (it was previously divided by the sample count).
        """

        model = self.nn
        model.train()

        total_loss = 0.0
        lr_sum = 0.0
        batches_seen = 0

        # Iterate over batches
        for step, batch in enumerate(train_dataloader):

            preds, labels = self.evaluate_single_batch(batch, model, self.params['device'])

            # Compute the loss between actual and predicted values
            loss = compute_loss(preds, labels)

            # Backward pass to calculate the gradients
            loss.backward()

            # Add to total loss
            total_loss += loss.item()

            # Accumulate gradients over `accumulation_steps` batches.
            # NOTE(review): gradients of a trailing partial group are not
            # applied within this epoch — confirm whether that is intended.
            if (step + 1) % self.params['accumulation_steps'] == 0:

                # Update parameters
                self.optimizer.step()
                self.scheduler.step()

                # Zero the parameter gradients
                self.optimizer.zero_grad()

            # Record the LR in effect after this batch
            lr_sum += self.optimizer.param_groups[0]['lr']
            batches_seen += 1

        # Compute the train loss of the epoch
        avg_loss = total_loss / number_of_train_samples
        if batches_seen:
            avg_lr = lr_sum / batches_seen
        else:
            avg_lr = self.optimizer.param_groups[0]['lr']

        return avg_loss, avg_lr

    def evaluate_single_epoch(self, val_samples, val_dataloader):
        """
        Evaluate the model on a dataloader without tracking gradients.

        Returns ``(avg_loss, preds, labels)`` where ``preds`` and
        ``labels`` are plain Python lists with one entry per sample.
        """

        model = self.nn
        model.eval()

        total_loss = 0.0
        preds_total = []
        labels_total = []

        # Deactivate autograd for the whole pass
        with torch.no_grad():

            for batch in val_dataloader:

                # Generate predictions
                preds, labels = self.evaluate_single_batch(batch, model, self.params['device'])
                preds_total.extend(preds)
                labels_total.extend(labels)

                # Compute the validation loss between actual and predicted values
                total_loss += compute_loss(preds, labels).item()

        # Compute the evaluation loss of the epoch
        preds_total = [x.tolist() for x in preds_total]
        labels_total = [x.tolist() for x in labels_total]
        avg_loss = total_loss / val_samples

        return avg_loss, preds_total, labels_total


def create_dataloader(data, sampler_class, batch_size, collate_fn=None):
    """Build a DataLoader over ``data`` using an instance of ``sampler_class``."""

    return DataLoader(
        data,
        sampler=sampler_class(data),
        batch_size=batch_size,
        collate_fn=collate_fn)


def create_train_dataloader(train_data, batch_size, collate_fn=None):
    """Dataloader that visits the training data in random order."""

    return create_dataloader(train_data, RandomSampler, batch_size, collate_fn)


def create_val_dataloader(val_data, batch_size, collate_fn=None):
    """Dataloader that visits the validation data in order."""

    return create_dataloader(val_data, SequentialSampler, batch_size, collate_fn)


def create_test_dataloader(test_data, batch_size, collate_fn=None):
    """Dataloader that visits the test data in order."""

    return create_dataloader(test_data, SequentialSampler, batch_size, collate_fn)


def create_dataloaders(train_data, val_data, batch_size, collate_fn=None):
    """Return ``(train_dataloader, val_dataloader)`` for the two splits."""

    train_dataloader = create_train_dataloader(
        train_data, batch_size, collate_fn)
    val_dataloader = create_val_dataloader(val_data, batch_size, collate_fn)

    return train_dataloader, val_dataloader


def compute_loss(preds, labels):
    """Summed binary cross-entropy between probabilities and targets."""

    return F.binary_cross_entropy(preds, labels.type_as(preds), reduction='sum')
labels): + self.texts = texts + self.labels = labels + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + return self.texts[idx], self.labels[idx] + + +class TokenizedDataset(Dataset): + """ + Dataset for tokens with labels + """ + + def __init__(self, tokens, labels): + self.input_ids = tokens['input_ids'] + self.attention_mask = tokens['attention_mask'] + self.labels = labels + + def __len__(self): + return len(self.labels) + + def __getitem__(self, idx): + return self.input_ids[idx], self.attention_mask[idx], self.labels[idx] + + +def collate_fn_pooled_tokens(data): + + input_ids = [data[i][0] for i in range(len(data))] + attention_mask = [data[i][1] for i in range(len(data))] + labels = [data[i][2] for i in range(len(data))] + collated = [input_ids, attention_mask, labels] + + return collated diff --git a/long_roberta/evaluate_models.ipynb b/long_roberta/evaluate_models.ipynb new file mode 100644 index 0000000..0a96939 --- /dev/null +++ b/long_roberta/evaluate_models.ipynb @@ -0,0 +1,390 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5d7b1d18", + "metadata": {}, + "source": [ + "### Evaluate Models\n" + ] + }, + { + "cell_type": "markdown", + "id": "80b1b2c7", + "metadata": {}, + "source": [ + "##### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87dc70f1", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import pickle\n", + "import json\n", + "import multiprocessing\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import torch.nn.functional as F\n", + "from sklearn import metrics\n", + "from datasets import load_from_disk, Dataset\n", + "from transformers import AutoTokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "ed3988f7", + "metadata": {}, + "source": [ + "##### Evaluation Parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a1fc80e", + "metadata": {}, + "outputs": [], 
+ "source": [ + "threshold = 0.5 # currently we don't maximize val f1 to find the threshold... need to grab scores for all the val sets if we do this\n", + "num_std = 1.96\n", + "num_bootstrap = 1000\n", + "line_width = 2\n", + "alpha = 0.2\n", + "font_size = 16\n", + "legend_size = 10\n", + "x_size = 10\n", + "y_size = 10" + ] + }, + { + "cell_type": "markdown", + "id": "1ac2d927-76e0-48ab-8777-48bc70206d07", + "metadata": {}, + "source": [ + "##### Initialize Score, Model, and Color Arrays" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acc793b2-80d0-45ac-9a3a-15dcf8fb53fb", + "metadata": {}, + "outputs": [], + "source": [ + "# Define master lists of labels, scores, names, and colors\n", + "all_y_trues, all_y_scores, all_model_names, all_colors = [], [], [], []" + ] + }, + { + "cell_type": "markdown", + "id": "aef4eaf2-fff5-4f83-8ade-2367a2513aa8", + "metadata": {}, + "source": [ + "##### Load Fine-Tuned Torch LM Results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26d15d8e-81cc-4cd5-a2dc-765325cb5a55", + "metadata": {}, + "outputs": [], + "source": [ + "file_info = [('a', 'b', 'c'), ('x', 'y', 'z') ]\n", + " \n", + " \n", + "for label_file, score_file, model_name in file_info: \n", + " with open(label_file, \"rb\") as f: \n", + " labels = pickle.load(f) \n", + " with open(score_file, \"rb\") as f: \n", + " scores = pickle.load(f)\n", + " \n", + " # In the case of the 2048 model, get the score for the 1 label\n", + " if \"RoBERTa (2048)\" in model_name:\n", + " scores = scores[:,1]\n", + " \n", + " all_model_names.append(model_name) \n", + " all_y_trues.append(labels) \n", + " all_y_scores.append(scores) " + ] + }, + { + "cell_type": "markdown", + "id": "248ce15f-2ece-4a51-ab68-83889c25be80", + "metadata": {}, + "source": [ + "##### Define Recall at Precision Metric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d34e75b9-4adc-4160-af56-9fc85a0c217c", + "metadata": {}, + "outputs": [], 
+ "source": [ + "def recall_at_precision(scores, labels, target_precision):\n", + " \n", + " # Compute precision-recall curve \n", + " precision, recall, thresholds = metrics.precision_recall_curve(labels, scores) \n", + "\n", + " # Find the highest recall where precision >= target_precision \n", + " max_recall = recall[np.where(precision >= target_precision)].max() \n", + "\n", + " return max_recall " + ] + }, + { + "cell_type": "markdown", + "id": "6baed460", + "metadata": {}, + "source": [ + "##### Define a Function to Print the Mean and Confidence Interval for a Given Metric" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26a6d07e", + "metadata": {}, + "outputs": [], + "source": [ + "def print_mean_ci_of_metric_list(metric_list, metric_name, num_std):\n", + " mean_metric = np.mean(metric_list)\n", + " std_metric = np.std(metric_list)\n", + " metric_low = np.maximum(mean_metric - std_metric * num_std, 0)\n", + " metric_high = np.minimum(mean_metric + std_metric * num_std, 1)\n", + "\n", + " print(\n", + " f\"{metric_name}: {round(mean_metric, 3)} ([{round(metric_low, 3)} - {round(metric_high, 3)}] 95% CI)\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "39681795-9909-4c09-b84d-f229b4663c4c", + "metadata": {}, + "source": [ + "##### Define a Function to Select a Threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93cc4daf-6656-4518-8f68-21ea77ab3161", + "metadata": {}, + "outputs": [], + "source": [ + "def get_threshold_of_best_val_f1(val_scores, val_labels):\n", + " \n", + " # Find the best threshold by maximizing F1 score\n", + " print(\" Computing best threshold for F1 on validation set...\")\n", + " best_val_f1 = 0\n", + " best_threshold = 0\n", + " for int_threshold in range(0, 100, 1):\n", + " threshold = int_threshold / 100\n", + " sample_preds = [1 if x >= threshold else 0 for x in val_probs]\n", + " f1 = metrics.f1_score(y_true=val_labels, y_pred=sample_preds)\n", + " if f1 > 
best_val_f1:\n", + " print(f\" Found new best F1 {f1:.4f} at threshold {threshold}\")\n", + " best_val_f1 = f1\n", + " best_threshold = threshold\n", + " \n", + " return best_threshold" + ] + }, + { + "cell_type": "markdown", + "id": "6f21cd42", + "metadata": {}, + "source": [ + "##### Print Performance for all Metrics for all Models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d6ff4a8", + "metadata": {}, + "outputs": [], + "source": [ + "mean_fpr_linspace = np.linspace(0, 1, 100)\n", + "mean_recall_linspace = np.linspace(0, 1, 100)\n", + "\n", + "model2metric_df = {}\n", + "for y_trues, y_scores, name in zip(\n", + " all_y_trues, all_y_scores, all_model_names\n", + "):\n", + " accuracies, recalls, precisions, aps, interp_ps, roc_aucs, interp_tprs, f1s, rs_at_p90, static_fprs, static_tprs = [], [], [], [], [], [], [], [], [], [], []\n", + " for i in range(num_bootstrap):\n", + " \n", + " # Sample N records with replacement where N is the total number of records\n", + " sample_indices = np.random.choice(len(y_trues), len(y_trues))\n", + " sample_labels = np.array(y_trues)[sample_indices]\n", + " sample_scores = np.array(y_scores)[sample_indices]\n", + " \n", + " # Generate thresholded prediction\n", + " # threshold = get_threshold_of_best_val_f1(val_scores=y_val_scores, val_labels=y_val_trues)\n", + " sample_preds = [1 if x >= threshold else 0 for x in sample_scores]\n", + "\n", + " accuracy = metrics.accuracy_score(y_true=sample_labels, y_pred=sample_preds)\n", + " accuracies.append(accuracy)\n", + " \n", + "# recall = metrics.recall_score(y_true=sample_labels, y_pred=sample_preds)\n", + "# recalls.append(recall)\n", + "\n", + "# precision = metrics.precision_score(y_true=sample_labels, y_pred=sample_preds)\n", + "# precisions.append(precision)\n", + " \n", + "# f1 = metrics.f1_score(y_true=sample_labels, y_pred=sample_preds)\n", + "# f1s.append(f1)\n", + " \n", + " ap = metrics.average_precision_score(y_true=sample_labels, 
y_score=sample_scores)\n", + " aps.append(ap)\n", + " \n", + " p, r, thresholds = metrics.precision_recall_curve(y_true=sample_labels, probas_pred=sample_scores)\n", + " interp_p = np.interp(mean_recall_linspace, np.fliplr([r])[0], np.fliplr([p])[0])\n", + " interp_ps.append(interp_p)\n", + " \n", + " roc_auc = metrics.roc_auc_score(y_true=sample_labels, y_score=sample_scores)\n", + " roc_aucs.append(roc_auc)\n", + " \n", + " fpr, tpr, _ = metrics.roc_curve(y_true=sample_labels, y_score=sample_scores)\n", + " \n", + " if 'GPT-4' in name or 'Text Gen' in name:\n", + " static_fprs.append(fpr[1])\n", + " static_tprs.append(tpr[1])\n", + " else:\n", + " static_fprs.append(None)\n", + " static_tprs.append(None)\n", + " \n", + " interp_tpr = np.interp(mean_fpr_linspace, fpr, tpr)\n", + " interp_tpr[0] = 0.0\n", + " interp_tprs.append(interp_tpr)\n", + " \n", + " r_at_p90 = recall_at_precision(scores=sample_scores, labels=sample_labels, target_precision=0.9)\n", + " rs_at_p90.append(r_at_p90)\n", + "\n", + " # \"recalls\": recalls,\n", + " # \"precisions\": precisions,\n", + " # \"f1s\": f1s,\n", + " \n", + " metric_df = pd.DataFrame({\n", + " \"aps\": aps,\n", + " \"roc_aucs\": roc_aucs,\n", + " })\n", + " model2metric_df[name] = metric_df\n", + "\n", + " print(f\"\\nResults for {name}\\n\")\n", + " # print_mean_ci_of_metric_list(recalls, metric_name=\"Recall\", num_std=num_std)\n", + " # print_mean_ci_of_metric_list(precisions, metric_name=\"Precision\", num_std=num_std)\n", + " # print_mean_ci_of_metric_list(f1s, metric_name=\"F1\", num_std=num_std)\n", + " print_mean_ci_of_metric_list(aps, metric_name=\"Average Precision\", num_std=num_std)\n", + " print_mean_ci_of_metric_list(roc_aucs, metric_name=\"ROC AUC\", num_std=num_std)\n", + " \n", + "with open(f\"./model2metric_df.pkl\", \"wb\") as f:\n", + " pickle.dump(model2metric_df, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1eb8e9ab-ace6-458b-a596-eb503b3dc8f1", + "metadata": {}, + 
"outputs": [], + "source": [ + "model2metric_df = {k: v for k, v in model2metric_df.items() if 'Max' not in k}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cf652e2-a0a0-435e-9e0d-80f04c6e1e17", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_mean_with_95_ci(ax, data, metric, condition): \n", + " \n", + " metric_dict = {'aps': 'PR AUC', 'roc_aucs': 'ROC AUC'} \n", + " filtered_data = {k: v for k, v in data.items() if condition in k} \n", + " \n", + " means = [] \n", + " errors = [] \n", + " colors = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black'] \n", + " for model, df in filtered_data.items(): \n", + " mean = df[metric].mean() \n", + " std = df[metric].std() \n", + " ci = 1.96 * std \n", + " \n", + " means.append(mean) \n", + " errors.append(ci) \n", + " \n", + " y_pos = np.arange(len(filtered_data)) \n", + " \n", + " for i, model in enumerate(filtered_data.keys()): \n", + " ax.barh(y_pos[i], means[i], xerr=errors[i], color=colors[i], capsize=10, label=f'M{i}: {map_model_name(model)}') \n", + " \n", + " ax.set_yticks(y_pos) \n", + " ax.set_yticklabels(['M' + str(i) for i in range(len(filtered_data))]) \n", + " ax.set_xlabel(metric_dict[metric]) \n", + " ax.set_title(f'{metric_dict[metric]} for {condition} Prediction') \n", + "\n", + "conditions = ['x', 'y', 'z'] \n", + "metrics = ['aps', 'roc_aucs'] \n", + " \n", + "fig, axs = plt.subplots(3, 2, figsize=(10, 12)) \n", + " \n", + "for i, condition in enumerate(conditions): \n", + " for j, metric in enumerate(metrics): \n", + " plot_mean_with_95_ci(axs[i][j], model2metric_df, metric, condition) \n", + " \n", + "# Add a single legend for the entire plot \n", + "handles, labels = axs[0][0].get_legend_handles_labels() \n", + "fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), \n", + " ncol=len(handles), fancybox=True, shadow=True) \n", + "\n", + "# Add a single title for the entire plot \n", + "fig.suptitle(\"Test Set Performance (1,000 
def _build_optimizer_and_scheduler(network, params):
    """Create the AdamW optimizer and linear-warmup scheduler shared by both model classes."""

    optimizer = AdamW(
        network.parameters(),
        lr=params['learning_rate'],
        betas=(params['adam_beta1'], params['adam_beta2']),
        weight_decay=params['weight_decay'],
        eps=params['adam_epsilon']
    )

    # NOTE(review): the very large num_training_steps presumably makes the
    # post-warmup decay negligible — confirm this is the intent.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=params['warmup_steps'],
        num_training_steps=1000000000000
    )

    return optimizer, scheduler


def _concat_chunks(chunks, device):
    """
    Concatenate per-sample chunk matrices along dim 0 and move the result
    to ``device``. Replaces the former tensor -> Python list -> tensor
    round trip; assumes every element is a 2D (num_chunks, length) array.
    """

    return torch.cat([torch.as_tensor(x) for x in chunks], dim=0).to(device)


class BERTClassificationModel(Model):
    """Classifier that feeds each tokenized sample to the network as a single window."""

    def __init__(self):
        # The base class loads params.yml into self.params
        super().__init__()

        tokenizer, bert = load_pretrained_model()
        self.preprocessor = BERTTokenizer(tokenizer)
        self.dataset_class = TokenizedDataset
        self.nn = initialize_model(bert, self.params['device'])
        self.optimizer, self.scheduler = _build_optimizer_and_scheduler(self.nn, self.params)

    def evaluate_single_batch(self, batch, model, device):
        """Return CPU ``(preds, labels)`` for one batch; the last batch element is the labels."""

        # Push the batch to gpu
        batch = [t.to(device) for t in batch]

        # Predict
        model_input = batch[:-1]
        labels = batch[-1]
        preds = model(*model_input).cpu()
        labels = labels.float().cpu()

        return preds, labels


class BERTClassificationModelWithPooling(Model):
    """
    Classifier that scores every chunk of a sample and pools the chunk
    scores per sample with the configured ``pooling_strategy``.
    """

    def __init__(self):
        # The base class loads params.yml into self.params
        super().__init__()

        params = self.params
        tokenizer, bert = load_pretrained_model()
        self.preprocessor = BERTTokenizerPooled(
            tokenizer, params['size'], params['step'], params['minimal_length'], params['max_num_segments']
        )
        self.dataset_class = TokenizedDataset
        self.collate_fn = collate_fn_pooled_tokens
        self.nn = initialize_model(bert, params['device'])
        self.optimizer, self.scheduler = _build_optimizer_and_scheduler(self.nn, params)

    def evaluate_single_batch(self, batch, model, device):
        """
        Score all chunks of all samples in one forward pass and pool them
        back to one prediction row per sample.

        Returns ``(pooled_preds, labels)`` as CPU float tensors.
        Raises ValueError for an unknown ``pooling_strategy``.
        """

        # Extract elements from batch
        input_ids = batch[0]
        attention_mask = batch[1]
        labels = batch[2]
        number_of_chunks = [len(x) for x in input_ids]

        # Concatenate all chunks of all samples into one batch
        input_ids_combined = _concat_chunks(input_ids, device)
        attention_mask_combined = _concat_chunks(attention_mask, device)

        # Get model predictions for the combined batch and move them to CPU
        preds = model(input_ids_combined, attention_mask_combined).cpu()

        # Split result preds back into one (num_chunks, num_labels) block per
        # sample. Reducing over dim 0 also covers num_labels == 1, so the
        # former single-label branch (which lacked 'custom_agg') is not needed.
        preds_split = torch.split(preds, number_of_chunks)

        # Pooling - torch.max return tuples where the first element is the aggregate value
        strategy = self.params['pooling_strategy']
        if strategy == 'mean':
            pooled_preds = torch.stack([torch.mean(x, dim=0) for x in preds_split])
        elif strategy == 'max':
            pooled_preds = torch.stack([torch.max(x, dim=0)[0] for x in preds_split])
        elif strategy == 'custom_agg':
            c = self.params['custom_agg_c']
            # Blend of max and mean; the mean's weight grows with chunk count
            pooled_preds = torch.stack([
                (torch.max(x, dim=0)[0] + torch.mean(x, dim=0) * n / c) / (1 + n / c)
                for x, n in zip(preds_split, number_of_chunks)
            ])
        else:
            raise ValueError(f"Expected pooling strategy to be one of ['mean', 'max', 'custom_agg'] but got {self.params['pooling_strategy']}.")

        # Labels arrive as a Python list from the collate function
        labels_detached = torch.tensor(labels).float()

        return pooled_preds, labels_detached


def load_pretrained_model():
    """Return ``(tokenizer, model)`` as configured in params.yml."""

    tokenizer = load_tokenizer()
    model = load_bert()

    return tokenizer, model


def load_tokenizer():
    """Load the tokenizer found at ``tokenizer_path`` from params.yml."""

    with open("params.yml", "r") as stream:
        params = yaml.safe_load(stream)

    return AutoTokenizer.from_pretrained(params['tokenizer_path'])


def load_bert():
    """Load the base model found at ``bert_path`` from params.yml."""

    with open("params.yml", "r") as stream:
        params = yaml.safe_load(stream)

    return AutoModel.from_pretrained(
        params['bert_path'],
        num_labels=params['num_labels'],
        return_dict=True
    )


def initialize_model(bert, device):
    """Wrap ``bert`` in the classification architecture, move it to ``device`` and wrap it in DataParallel."""

    model = BERTSequenceClassificationArch(bert)
    model = model.to(device)
    model = nn.DataParallel(model)

    return model
Use np.stack(array) if passing an array of arrays." + assert len(self.labels.shape) == len(self.preds.shape) == 2, assert_msg + + def get_bootstrapped_average_precision(self, n_bootstrap=1000): + """ + Bootstrap sample the predictions and labels to + compute micro and macro average precisions across all + labels with the average and standard deviation of + these values across all boostrap iterations. + + :return: micro_average_precision_mean_stdv, macro_average_precision_mean_stdv + :rtype: (dict, dict) + """ + + # Ensure labels and preds are 2d arrays + self.assert_2d_array() + + # Run bootstrap iterations + micro_average_precision_mean_stdv, macro_average_precision_mean_stdv = {}, {} + micro_average_precisions, macro_average_precisions = [], [] + for i in range(n_bootstrap): + + # Sample N records with replacement where N is the total number of records + sample_indices = np.random.choice(len(self.labels), len(self.labels)) + sample_labels = self.labels[sample_indices] + sample_preds = self.preds[sample_indices] + + micro_average_precision = average_precision_score( + sample_labels, sample_preds, average="micro" + ) + micro_average_precisions.append(micro_average_precision) + + macro_average_precision = average_precision_score( + sample_labels, sample_preds, average="macro" + ) + macro_average_precisions.append(macro_average_precision) + + # Compute means and stdvs + micro_average_precision_mean_stdv["mean"] = np.mean(micro_average_precisions) + micro_average_precision_mean_stdv["stdv"] = np.std(micro_average_precisions) + macro_average_precision_mean_stdv["mean"] = np.mean(macro_average_precisions) + macro_average_precision_mean_stdv["stdv"] = np.std(macro_average_precisions) + + return micro_average_precision_mean_stdv, macro_average_precision_mean_stdv + + def get_bootstrapped_roc_auc(self, n_bootstrap=1000): + """ + Bootstrap sample the predictions and labels to + compute micro and macro ROC AUC across all + labels with the average and standard deviation of + 
def get_bootstrapped_f1(self, n_bootstrap=1000, threshold=0.5):
    """
    Bootstrap sample the predictions and labels to compute micro and
    macro F1 across all labels, together with the mean and standard
    deviation of these values across all bootstrap iterations.

    :param n_bootstrap: number of bootstrap iterations to run
    :param threshold: cut-off applied to ``self.preds``; entries greater
        than or equal to it are counted as positive predictions
        (defaults to the previously hard-coded 0.5)
    :return: micro_f1_mean_stdv, macro_f1_mean_stdv
    :rtype: (dict, dict)
    """

    # Ensure labels and preds are 2d arrays
    self.assert_2d_array()

    n_records = len(self.labels)
    micro_f1s, macro_f1s = [], []
    for _ in range(n_bootstrap):

        # Sample N records with replacement where N is the total number of records
        sample_indices = np.random.choice(n_records, n_records)
        sample_labels = self.labels[sample_indices]
        sample_preds = self.preds[sample_indices]

        # Binarise the scores once and reuse them for both averaging modes
        preds_at_threshold = (sample_preds >= threshold).astype(int)
        micro_f1s.append(
            f1_score(sample_labels, preds_at_threshold, average="micro")
        )
        macro_f1s.append(
            f1_score(sample_labels, preds_at_threshold, average="macro")
        )

    # Compute means and stdvs
    micro_f1_mean_stdv = {"mean": np.mean(micro_f1s), "stdv": np.std(micro_f1s)}
    macro_f1_mean_stdv = {"mean": np.mean(macro_f1s), "stdv": np.std(macro_f1s)}

    return micro_f1_mean_stdv, macro_f1_mean_stdv
def compute_training_metrics(pred, threshold=0.5):
    """
    Returns dictionary of metrics computed during training.

    ``pred.predictions`` is passed to ``binary_cross_entropy_with_logits``
    below, i.e. it holds raw logits. The F1 scores are therefore computed
    on sigmoid probabilities so that ``threshold`` is a probability
    cut-off rather than a cut-off on the logit scale.

    :param pred: object exposing ``predictions`` (logits) and
        ``label_ids`` (binary labels of the same shape)
    :param threshold: probability at or above which a label is predicted
    :return: training_metrics with keys 'micro_f1', 'macro_f1', 'loss'
    :rtype: dict
    """

    labels = pred.label_ids
    logits = torch.tensor(pred.predictions)

    # Compute f1s on probabilities, not on raw logits
    probabilities = torch.sigmoid(logits).numpy()
    preds_at_threshold = (probabilities >= threshold).astype(int)
    micro_f1 = f1_score(labels, preds_at_threshold, average="micro")
    macro_f1 = f1_score(labels, preds_at_threshold, average="macro")

    # Compute loss directly from the logits
    y_true = torch.tensor(labels).type_as(logits)
    loss = F.binary_cross_entropy_with_logits(logits, y_true).item()

    # Build metrics dict
    training_metrics = {
        'micro_f1': micro_f1,
        'macro_f1': macro_f1,
        'loss': loss
    }

    return training_metrics
+'learning_rate': 0.00005 +'num_labels': 1 +'adam_beta1': 0.9 +'adam_beta2': 0.999 +'adam_epsilon': 0.00000001 +'warmup_steps': 100 +'weight_decay': 0.01 +'seed': 1111 + +# Devices +'device': 'cuda' +'visible_gpus': "0" #"0,1,2,3,4,5,6,7" + +# Test data +test_with_imdb_data: False +imdb_data: 'sample_data/imdb_kaggle.csv' diff --git a/long_roberta/params_example.yml b/long_roberta/params_example.yml new file mode 100644 index 0000000..a43887c --- /dev/null +++ b/long_roberta/params_example.yml @@ -0,0 +1,49 @@ +# Data +'dataset_path': 'text_label.hf' + +# Pretrained LM +'tokenizer_path': 'roberta_512/' +'bert_path': 'checkpoint-500000/' + +# Output LM +'output_path': 'text_only_mean/' + +# Model Name +'model_name': 'text_only_mean' + +# Load from file +'model_load_from_file': False + +# Pooled BERT Parameters +'use_pooled_bert': True +'pooling_strategy': 'mean' # one of ['mean', 'max', 'custom_agg'] +'custom_agg_c': 2 +'size': 510 +'step': 100 +'minimal_length': 1 +'max_num_segments': 5 # Segments start every `step` tokens and span `size` tokens, so each extra segment adds `step` new tokens. With these values 5 segments cover 510 + 4 * 100 = 910 tokens. + +# Linear layer dim +'linear_dim': 768 #1024 + +# Training Parameters +'epochs': 10000000 +'early_stopping_epochs': 5 +'batch_size': 8 # Warning, from a memory consumption perspective, batches are ragged. This is the min # of chunks used in a forward pass. +'accumulation_steps': 16 # Because of the above, we don't exactly know the effective batch size. 
+'learning_rate': 0.00005 +'num_labels': 1 +'adam_beta1': 0.9 +'adam_beta2': 0.999 +'adam_epsilon': 0.00000001 +'warmup_steps': 100 +'weight_decay': 0.01 +'seed': 1111 + +# Devices +'device': 'cuda' +'visible_gpus': "0" #"0,1,2,3,4,5,6,7" + +# Test data +test_with_imdb_data: False +imdb_data: 'sample_data/imdb_kaggle.csv' diff --git a/long_roberta/pooling.py b/long_roberta/pooling.py new file mode 100644 index 0000000..d3f36a8 --- /dev/null +++ b/long_roberta/pooling.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +# coding: utf-8 + +""" +Functions for preparing input for longer texts - based on +https://www.kdnuggets.com/2021/04/apply-transformers-any-length-text.html +""" + +import torch + + +def tokenize_all_text(text, tokenizer): + """ + Tokenizes the entire text without truncation and without special tokens + + Parameters: + text - single str with arbitrary length + tokenizer - object of class transformers.PreTrainedTokenizerFast + + Returns: + tokens - dictionary of the form + { + 'input_ids' : [...] + 'token_type_ids' : [...] + 'attention_mask' : [...] 
def split_overlapping(array, size, step, minimal_length):
    """
    Helper function for dividing arrays into overlapping chunks.

    A chunk of at most ``size`` items starts every ``step`` items. When
    more than one chunk is produced, chunks shorter than
    ``minimal_length`` are dropped; a lone chunk is always kept.
    """

    chunks = [array[start:start + size] for start in range(0, len(array), step)]
    if len(chunks) > 1:

        # Ignore chunks with less than minimal_length number of tokens
        chunks = [chunk for chunk in chunks if len(chunk) >= minimal_length]

    return chunks


def split_tokens_into_smaller_chunks(tokens, size, step, minimal_length):
    """
    Splits tokens into overlapping chunks with given size and step.

    Reads the first row of ``tokens['input_ids']`` and
    ``tokens['attention_mask']`` and returns the two lists of chunks.
    """

    # Two positions are reserved for the special tokens added later
    assert size <= 510
    input_id_chunks = split_overlapping(
        tokens['input_ids'][0], size, step, minimal_length)
    mask_chunks = split_overlapping(
        tokens['attention_mask'][0], size, step, minimal_length)

    return input_id_chunks, mask_chunks


def add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks, tokenizer):
    """
    Adds special CLS token at the beginning and SEP token at the end of each chunk.

    Both lists are modified in place; the attention mask gets a 1 for
    each of the two added positions.
    """

    cls_id = torch.Tensor([tokenizer.cls_token_id])
    # The closing token must be SEP; the mask token was appended here
    # before, which contradicts the documented contract of this function.
    sep_id = torch.Tensor([tokenizer.sep_token_id])
    attended = torch.Tensor([1])

    for i, (ids, mask) in enumerate(zip(input_id_chunks, mask_chunks)):
        input_id_chunks[i] = torch.cat([cls_id, ids, sep_id])
        mask_chunks[i] = torch.cat([attended, mask, attended])


def add_padding_tokens(input_id_chunks, mask_chunks, tokenizer, max_length=512):
    """
    Adds padding tokens at the end to make sure that all chunks have exactly
    ``max_length`` tokens (512 by default).

    Both lists are modified in place. Input ids are padded with the
    tokenizer's pad token id; the attention mask is padded with zeros so
    the padded positions are not attended to, whatever the numeric value
    of the pad token id is.
    """

    for i, (ids, mask) in enumerate(zip(input_id_chunks, mask_chunks)):

        # get required padding length
        pad_len = max_length - ids.shape[0]

        # chunks that already have the required length are left untouched
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([
                ids, torch.Tensor([tokenizer.pad_token_id] * pad_len)
            ])
            mask_chunks[i] = torch.cat([mask, torch.zeros(pad_len)])
stack_tokens_from_all_chunks(input_id_chunks, mask_chunks): + """ + Reshapes data to a form compatible with BERT model input + """ + + input_ids = torch.stack(input_id_chunks) + attention_mask = torch.stack(mask_chunks) + + return input_ids.long(), attention_mask.int() + + +def transform_text_to_model_input( + text, + tokenizer, + size=510, + step=510, + minimal_length=100): + """ + Transforms the entire text to model input of BERT model + """ + + tokens = tokenize_all_text(text, tokenizer) + input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks( + tokens, size, step, minimal_length) + add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks, tokenizer) + add_padding_tokens(input_id_chunks, mask_chunks, tokenizer) + input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks) + + return [input_ids, attention_mask] diff --git a/long_roberta/prepare_datasets.ipynb b/long_roberta/prepare_datasets.ipynb new file mode 100644 index 0000000..52a100f --- /dev/null +++ b/long_roberta/prepare_datasets.ipynb @@ -0,0 +1,390 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "19b33b3a-5ac1-4a50-a13f-c5810869628a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/anaconda/envs/azureml_py310_sdkv2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import os\n", + "import re\n", + "import gc\n", + "import multiprocessing\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from datasets import load_from_disk, Dataset, DatasetDict \n", + "from sklearn.model_selection import train_test_split\n", + "from transformers import AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "be7b591c-ceb8-41ab-b2ec-70d88c5b30b2", + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"nlp_classification_tasks/\"\n", + "conditions = [\"x\", \"y\", \"z\"]\n", + "time_slice = \"a\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2117c1c8-42ff-4a74-9ea1-936ad9b52089", + "metadata": {}, + "outputs": [], + "source": [ + "val_size = 0.1\n", + "seed = 22\n", + "cols = [\"text\", \"label\"]\n", + "max_seq_len = 8192" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3947b18d-102c-4ab6-af49-4fced0dc9dfb", + "metadata": {}, + "outputs": [], + "source": [ + "bioclinroberta_path = 'RoBERTa-base-PM-M3-Voc-distill-align-hf/'\n", + "tokenizer = AutoTokenizer.from_pretrained(bioclinroberta_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "17b24dda-6a78-45ef-851a-239251ad27b3", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_token_len_distribution(dataset):\n", + "\n", + " plt.hist([len(ids) for ids in dataset['input_ids']], bins=50)\n", + " plt.ylabel(\"Frequency\")\n", + " plt.xlabel(\"Tokens per Patient\")\n", + " plt.title(\"Distribution of Tokens per Patient\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "91c99dec-d620-4816-abab-7dd96f44562e", + "metadata": {}, + "outputs": [], + "source": [ + "def tokenize_text(record, tokenizer, truncate_to):\n", + " \n", + " return {\n", + " 'input_ids': tokenizer(\n", + 
" record['text'],\n", + " padding=False,\n", + " truncation=True,\n", + " max_length=truncate_to\n", + " )['input_ids']\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "727a5878-b361-45c2-b929-ebedf868e14f", + "metadata": {}, + "outputs": [], + "source": [ + "def tokenize_plot(dataset, tokenizer, truncate_to, batch_size=512, proc_div=2):\n", + " \n", + " num_proc = int(multiprocessing.cpu_count() / proc_div)\n", + " print(f\"Tokenizing with {num_proc} CPU processes...\")\n", + " \n", + " dataset = dataset.map(\n", + " tokenize_text,\n", + " batched=True,\n", + " batch_size=batch_size,\n", + " fn_kwargs={\n", + " \"tokenizer\": tokenizer,\n", + " \"truncate_to\": truncate_to\n", + " },\n", + " num_proc=num_proc\n", + " )\n", + " \n", + " plot_token_len_distribution(dataset)\n", + " dataset.remove_columns('input_ids')\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "87f84d55-6aa7-47b0-9b44-f08b0bd17443", + "metadata": {}, + "outputs": [], + "source": [ + "def split_on_sole_pipe(input_string): \n", + " \n", + " return re.split(r'(?" 
+ ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving the dataset (1/1 shards): 100%|██████████| 73248/73248 [00:03<00:00, 21244.96 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 8139/8139 [00:00<00:00, 26572.39 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 28898/28898 [00:01<00:00, 22993.47 examples/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing data for frailty...\n", + "Train size: 41978\n", + "Val size: 4665\n", + "Test size: 4483\n", + "Tokenizing with 12 CPU processes...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map (num_proc=12): 100%|██████████| 41978/41978 [00:07<00:00, 5806.34 examples/s] \n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving the dataset (1/1 shards): 100%|██████████| 41978/41978 [00:01<00:00, 26285.94 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 4665/4665 [00:00<00:00, 19336.65 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 4483/4483 [00:00<00:00, 31687.16 examples/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing data for mortality...\n", + "Train size: 34459\n", + "Val size: 3829\n", + "Test size: 6637\n", + "Tokenizing with 12 CPU processes...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map (num_proc=12): 100%|██████████| 34459/34459 [00:07<00:00, 4543.71 examples/s]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving the dataset (1/1 shards): 100%|██████████| 34459/34459 [00:01<00:00, 20791.68 examples/s] \n", + "Saving the dataset (1/1 shards): 100%|██████████| 3829/3829 [00:00<00:00, 23462.47 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 6637/6637 [00:00<00:00, 26327.15 examples/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing data for ckd2345_2022_nlp...\n", + "Train size: 18577\n", + "Val size: 2065\n", + "Test size: 2887\n", + "Tokenizing with 12 CPU processes...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Map (num_proc=12): 100%|██████████| 18577/18577 [00:03<00:00, 5233.59 examples/s] \n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving the dataset (1/1 shards): 100%|██████████| 18577/18577 [00:01<00:00, 11427.14 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 2065/2065 [00:00<00:00, 14968.99 examples/s]\n", + "Saving the dataset (1/1 shards): 100%|██████████| 2887/2887 [00:00<00:00, 11546.38 examples/s]\n" + ] + } + ], + "source": [ + "for condition in conditions:\n", + " \n", + " # Load dataset\n", + " print(f\"Processing data for {condition}...\")\n", + " dataset_name = f\"{time_slice}_{condition}_ids_text_feats_concatenated.hf\"\n", + " d = load_from_disk(os.path.join(data_dir, dataset_name))\n", + " \n", + " # Convert to dataframes and create validation split\n", + " train_df = d['train'].to_pandas()\n", + " test_df = d['test'].to_pandas()\n", + " train_df, val_df = train_test_split(train_df, test_size=val_size, random_state=seed, stratify=train_df['label'].values)\n", + " \n", + " # Select text, and label columns\n", + " train_df, val_df, test_df = train_df[cols], val_df[cols], test_df[cols]\n", + " \n", + " # Remove duplicate notes\n", + " train_df['text'] = train_df['text'].apply(remove_duplicates_join_on_sep)\n", + " val_df['text'] = val_df['text'].apply(remove_duplicates_join_on_sep)\n", + " test_df['text'] = test_df['text'].apply(remove_duplicates_join_on_sep)\n", + " \n", + " # Count data split sizes\n", + " print(f\"Train size: {train_df.shape[0]}\")\n", + " print(f\"Val size: {val_df.shape[0]}\")\n", + " print(f\"Test size: {test_df.shape[0]}\")\n", + " \n", + " # Create new HF datasets\n", + " train_dataset = Dataset.from_pandas(train_df) \n", + " val_dataset = Dataset.from_pandas(val_df) \n", + " test_dataset = Dataset.from_pandas(test_df)\n", + " \n", + " # View token length distribution\n", + " tokenize_plot(train_dataset, tokenizer, truncate_to=max_seq_len)\n", + " \n", + " # Create new dataset dict\n", + " dataset_dict = 
{ \n", + " \"train\": train_dataset, \n", + " \"val\": val_dataset, \n", + " \"test\": test_dataset \n", + " } \n", + " dataset_dict = DatasetDict(dataset_dict) \n", + " \n", + " # Save dataset dict\n", + " new_dataset_name = f\"{time_slice}_{condition}_text_label.hf\"\n", + " dataset_dict.save_to_disk(os.path.join(data_dir, new_dataset_name))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc81ceba-a91a-4e63-bac6-2c4d33b46a88", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10 - SDK v2", + "language": "python", + "name": "python310-sdkv2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/long_roberta/requirements.txt b/long_roberta/requirements.txt new file mode 100644 index 0000000..92de230 --- /dev/null +++ b/long_roberta/requirements.txt @@ -0,0 +1,13 @@ +transformers==4.34.0 +tokenizers==0.14.1 +PyYAML +sentencepiece +scikit-learn +pandas +datasets +protobuf +matplotlib +torch==1.13.1 +accelerate +pyarrow +pytest \ No newline at end of file diff --git a/long_roberta/run_exps.py b/long_roberta/run_exps.py new file mode 100644 index 0000000..4639542 --- /dev/null +++ b/long_roberta/run_exps.py @@ -0,0 +1,88 @@ +import os +import subprocess +import yaml + +def update_params_yaml( + params, + dataset_path, + output_path, + model_name, + pooling_strategy, + tokenizer_path, + bert_path + ): + params['dataset_path'] = dataset_path + params['output_path'] = output_path + params['model_name'] = model_name + params['pooling_strategy'] = pooling_strategy + params['tokenizer_path'] = tokenizer_path + params['bert_path'] = bert_path + + with open('params.yml', 'w') as params_file: + yaml.dump(params, 
def run_train_and_evaluate():
    """
    Run train_and_evaluate.py as a child process and report its exit status.

    The script is launched with the interpreter that is running this
    file, so the child uses the same environment. A non-zero exit code
    is reported on stderr instead of being silently discarded; the sweep
    itself is not interrupted.

    :return: exit code of the child process
    :rtype: int
    """

    import sys  # local import: only needed by this function

    completed = subprocess.run([sys.executable, 'train_and_evaluate.py'])
    if completed.returncode != 0:
        print(
            f"train_and_evaluate.py exited with code {completed.returncode}",
            file=sys.stderr,
        )

    return completed.returncode
def tokenize(texts, tokenizer):
    '''
    Tokenizes a batch of texts, padded and truncated to at most 512 tokens

    Parameters:
    texts - iterable of strings (converted to a list before encoding)
    tokenizer - object of class transformers.PreTrainedTokenizerFast

    Returns:
    tokens - the value returned by tokenizer.batch_encode_plus with
             return_tensors='pt'
    '''

    texts = list(texts)
    # NOTE(review): this overwrites the pad token on the caller's tokenizer
    # with an empty string on every call. An empty pad token looks
    # unintended (a literal may have been lost) - confirm the intended value.
    tokenizer.pad_token = ""
    tokens = tokenizer.batch_encode_plus(
        texts,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt')

    return tokens
#!/usr/bin/env python
# coding: utf-8

# Train the configured classifier, pick the epoch with the lowest
# validation loss, and save test scores, labels, bootstrapped metrics
# and learning curves. All settings are read from params.yml.

# Open imports
import os
import json
import yaml
import logging
import pickle
import numpy as np
import pandas as pd
from datasets import load_from_disk

# Project imports
from main import BERTClassificationModelWithPooling, BERTClassificationModel
from utils import check_empty_count_gpus, create_current_run, np_sigmoid, load_and_split_imdb_data, plot_learning_curve
from metrics import BootstrapMultiLabelMetrics

# Load run parameters
with open("params.yml", "r") as stream:
    PARAMS = yaml.safe_load(stream)

# Set gpus. This must happen before the first CUDA query in this process
# (the GPU check below), otherwise the restriction is not picked up.
os.environ["CUDA_VISIBLE_DEVICES"] = str(PARAMS['visible_gpus'])

# Define logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Log parameters
logger.info(PARAMS)

# Check, empty, and count GPUs
check_empty_count_gpus(logger=logger)

# Create run directory
current_run_dir = create_current_run(save_path=PARAMS['output_path'], params=PARAMS, logger=logger)
logger.info(f"Created run directory: {current_run_dir}.")

# Set run name. The run directory ends with a path separator, so it is
# normalised first; splitting on '/' would always give an empty name.
run_name = os.path.basename(os.path.normpath(current_run_dir))
logger.info(f"Starting run {run_name}...")

# Load data
if PARAMS['test_with_imdb_data']:

    # Use IMDB data to test model
    X_train, X_val, X_test, y_train, y_val, y_test = load_and_split_imdb_data(PARAMS['imdb_data'], num_labels=PARAMS['num_labels'])

else:

    # Load real data
    d = load_from_disk(PARAMS['dataset_path'])
    X_train = np.array(d['train']['text'])
    y_train = np.array(d['train']['label']).reshape(-1, 1)
    X_val = np.array(d['val']['text'])
    y_val = np.array(d['val']['label']).reshape(-1, 1)
    X_test = np.array(d['test']['text'])
    y_test = np.array(d['test']['label']).reshape(-1, 1)

# Print data shapes
logger.info(f'Train shapes: {len(X_train), y_train.shape}')
logger.info(f'Val shapes: {len(X_val), y_val.shape}')
logger.info(f'Test shapes: {len(X_test), y_test.shape}')

# Load model
if PARAMS['use_pooled_bert']:
    model = BERTClassificationModelWithPooling()
else:
    model = BERTClassificationModel()

# Train and evaluate
result = model.train_and_evaluate(
    X_train,
    X_val,
    X_test,
    y_train,
    y_val,
    y_test,
    epochs=PARAMS['epochs'],
    early_stopping_epochs=PARAMS['early_stopping_epochs'],
    logger=logger
)

# Find best epoch (lowest validation loss)
best_epoch = np.argmin(result['val_loss'])
logger.info(f'Val losses: {result["val_loss"]}.')
logger.info(f'Best val loss: {np.min(result["val_loss"])}.')
logger.info(best_epoch)

# Get test preds
test_preds = np.array(result['test_preds'][best_epoch])
test_labels = np.array(result['test_labels'][best_epoch])

# Save final preds and labels
with open(f"./{PARAMS['model_name']}_scores.pkl", "wb") as f:
    pickle.dump(test_preds, f)
with open(f"./{PARAMS['model_name']}_labels.pkl", "wb") as f:
    pickle.dump(test_labels, f)

# Compute final performance
evaluator = BootstrapMultiLabelMetrics(labels=test_labels, preds=test_preds)
metrics_dict = evaluator.get_all_bootstrapped_metrics_as_dict(n_bootstrap=1000)
logger.info(metrics_dict)

# Save metrics
with open(os.path.join(current_run_dir, 'metrics.json'), "w") as f:
    json.dump(metrics_dict, f)
with open(f'./{PARAMS["model_name"]}_metrics.json', "w") as f:
    json.dump(metrics_dict, f)

# Plot learning curves from training (test-set entries are excluded)
nresult = {k: v for k, v in result.items() if 'test' not in k}
plot_learning_curve(nresult, current_run_dir, prefix=PARAMS['model_name'])
plot_learning_curve(nresult, './', prefix=PARAMS['model_name'])
def create_current_run(save_path, params, logger=None):
    """
    Create a directory for the current run, save the
    current pipeline parameters, and return the
    path to the current run directory.

    Run directories are named ``run_<N>``; the new one gets the highest
    existing N plus one (``run_0`` when there is none). ``save_path`` is
    created when missing, and entries that do not follow the ``run_<N>``
    naming are ignored instead of aborting the run.

    :param save_path: directory that holds the run directories
    :param params: parameters written to ``params.yml`` inside the run directory
    :param logger: optional logger used for progress messages
    :return: path of the new run directory, ending with a path separator
    :rtype: str
    """

    # Create current run dir
    os.makedirs(save_path, exist_ok=True)
    run_ids = []
    for entry in os.listdir(save_path):
        parts = entry.split('_')
        if len(parts) >= 2 and parts[0] == 'run' and parts[1].isdigit():
            run_ids.append(int(parts[1]))
    next_run = max(run_ids, default=-1) + 1

    # The trailing separator is kept: the returned value has always ended with one
    current_run_dir = os.path.join(save_path, 'run_' + str(next_run) + '/')
    os.makedirs(current_run_dir)

    if logger:
        logger.info(f'Created current run dir: {current_run_dir}.')

    # Save run params in current run dir for reference
    with open(os.path.join(current_run_dir, 'params.yml'), 'w') as stream:
        yaml.dump(params, stream, default_flow_style=False)

    if logger:
        logger.info(f'Saved run parameter to current run dir.')

    return current_run_dir


def check_empty_count_gpus(logger=None):
    """
    Check that GPU is available, empty the cache,
    and count the number of available devices.

    :param logger: optional logger used to report the device count
    :return: number of CUDA devices reported by torch
    :rtype: int
    """

    # Check that a GPU is available:
    assert torch.cuda.is_available(), 'No GPU found. Please run on a GPU.'

    # Empty GPU cache
    torch.cuda.empty_cache()

    # Count available devices
    device_count = torch.cuda.device_count()

    if logger:
        logger.info(f'Found {device_count} GPU(s)!')

    return device_count
def to_binary_one_hot(y):
    """
    Convert 0 and 1 labels to
    [1, 0] and [0, 1] for generality.

    :param y: sequence or array of binary labels
    :return: integer array of shape (len(y), 2)
    """

    labels = np.asarray(y).reshape(-1)
    yn = np.zeros((labels.shape[0], 2), dtype=int)
    yn[:, 0] = 1 - labels  # 0 -> 1 & 1 -> 0
    yn[:, 1] = labels      # 0 -> 0 & 1 -> 1

    return yn


def load_and_split_imdb_data(path, seed=42, num_labels=1):
    """
    Load and split imdb data for testing code.

    Reads the 'sentence' and 'target' columns of the CSV at ``path``,
    holds out 15% as the test set (shuffled) and 15% of the remainder
    as the validation set (not shuffled again).

    :param path: path of the CSV file
    :param seed: random state passed to both splits
    :param num_labels: 1 for a single binary column, 2 for one-hot labels
    :return: X_train, X_val, X_test, y_train, y_val, y_test
    :raises ValueError: if ``num_labels`` is neither 1 nor 2
    """

    # Read data and create features and labels
    df = pd.read_csv(path)
    texts = df['sentence'].tolist()
    labels = df['target'].tolist()

    # Create train, val, test splits
    X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=.15, random_state=seed, shuffle=True)
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=.15, random_state=seed, shuffle=False)

    # Convert binary labels to one hot labels for generality or use a single binary label wrapped in an extra dim
    if num_labels == 1:
        y_train = np.array([[x] for x in y_train])
        y_val = np.array([[x] for x in y_val])
        y_test = np.array([[x] for x in y_test])
    elif num_labels == 2:
        y_train = to_binary_one_hot(y_train)
        y_val = to_binary_one_hot(y_val)
        y_test = to_binary_one_hot(y_test)
    else:
        raise ValueError("For this dataset, the labels should be encoded using either 1 or 2 columns.")

    return X_train, X_val, X_test, y_train, y_val, y_test


def plot_learning_curve(result, current_run_dir, prefix):
    """
    Plot every series in ``result`` against the epoch index and save the
    figure as ``<prefix>_learning_curves.png`` inside ``current_run_dir``.

    :param result: mapping from series name to a sequence of per-epoch values
    :param current_run_dir: directory the image is written to
    :param prefix: file name prefix of the image
    """

    cmap = plt.get_cmap("tab10")
    fig, ax = plt.subplots(figsize=(10, 10))

    for i, (key, value) in enumerate(result.items()):
        ax.plot(value, '-', label=key, color=cmap(i))

    ax.legend(loc='upper right')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Value')
    ax.set_title('Learning Curves')
    fig.tight_layout()

    # Joining the path works with or without a trailing separator on the directory
    fig.savefig(os.path.join(current_run_dir, f'{prefix}_learning_curves.png'), transparent=False)

    # Release the figure; it would otherwise stay registered with pyplot
    plt.close(fig)