- Authors: Dongyue Li, Ziniu Zhang, Lu Wang and Hongyang R. Zhang
- Paper: arXiv
This code implements a fast method for the Estimation of fine-tuning model losses using Gradients (GradEx). Given a list of task subsets, the method estimates a language model's fine-tuning losses on those subsets without repeated fine-tuning: it replaces the repeated fine-tuning runs with solving logistic regression problems that use projected gradients as features. It can be applied to subset selection problems, such as task/data selection for fine-tuning language models. We provide code for experiments on chain-of-thought fine-tuning and instruction fine-tuning.
To set up the environment, please run the following commands.

```bash
conda create -n gradex python=3.10
conda activate gradex
pip install -r requirements.txt
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124  # pick the PyTorch nightly build that matches your CUDA version
mkdir ./data
mkdir ./results
mkdir ./external_lightning_logs
python setup.py develop
```
We provide stand-alone examples illustrating the use cases of our algorithm:
- Select correct examples from a noisy synthetic dataset: `examples/example_noisy_synthetic_addition_task.ipynb`
- Estimate Llama-3-8B model fine-tuning losses: `example_approximate_llama_finetuning_loss.ipynb`
Chain-of-thought fine-tuning. Please refer to the reasoning-teacher repository for downloading the chain-of-thought data, including CommonsenseQA and StrategyQA.
Instruction fine-tuning.
- Alpaca: Please download the data from this link and put the pickle file under the `./data/alpaca_data` folder.
- FLAN v2: Please refer to the open-instruct repository for downloading the FLAN v2 (and CoT) instruction fine-tuning data. After downloading, save the datasets as pickle files under the `./data/alpaca_data` folder. For example:
```python
from datasets import load_dataset
import pandas as pd

flan_dataset = load_dataset("json", data_files="./raw_train/tulu_v1_resampled_flan_100k.jsonl")["train"]

def reformat_flan(example):
    # Concatenate the prompt and completion into a single training text.
    prompt = example["inputs"]
    if not prompt.endswith("\n") and not prompt.rstrip().endswith(":"):
        prompt += "\n"
    completion = example["targets"]
    example["text"] = prompt + completion
    # Keep the task name so that examples can later be grouped by task.
    example["skill"] = example["_task_name"]
    return example

flan_dataset_df = flan_dataset.map(reformat_flan)
pd.to_pickle(flan_dataset_df, "./flan_dataset.pkl")
```
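As an optional sanity check (not part of the pipeline), you can confirm that the saved pickle loads and contains the `text` and `skill` fields created above. This assumes the pickle holds the mapped `datasets.Dataset` object from the snippet:

```python
# Optional sanity check on the saved pickle (assumes it stores the mapped Dataset).
import pandas as pd

flan_dataset_df = pd.read_pickle("./flan_dataset.pkl")
example = flan_dataset_df[0]  # integer indexing returns a dict of fields
print(example["skill"])
print(example["text"][:200])
```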
Our algorithm contains the following steps:
- Meta training: Multitask training on all tasks to obtain a meta-initialization. Then, we evaluate and project the gradients of all training samples on the meta-initialization.
- Estimation: Estimate model fine-tuning performance on a list of task subsets, using the projected gradients as features in logistic regression (a minimal sketch follows this list).
- Selection: Use the estimated results to conduct subset selection, including forward stepwise selection and random ensembles.
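To make the estimation step concrete, here is a minimal sketch of fitting logistic regression on projected-gradient features to score a task subset. The file names, array shapes, and data layout below are illustrative assumptions, not the repository's actual interfaces.

```python
# Minimal sketch of the estimation step; file names and shapes are assumptions.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

# Projected gradients of all training samples at the meta-initialization,
# together with per-sample targets and task membership (hypothetical files).
G = np.load("./gradients/projected_gradients.npy")  # (n_samples, project_dim)
y = np.load("./gradients/labels.npy")                # (n_samples,)
task_ids = np.load("./gradients/task_ids.npy")       # (n_samples,)

def estimate_subset_loss(subset, holdout_mask):
    """Fit logistic regression on the subset's gradient features and return its
    held-out loss, as a stand-in for actually fine-tuning on that subset."""
    in_subset = np.isin(task_ids, list(subset))
    clf = LogisticRegression(max_iter=1000)
    clf.fit(G[in_subset & ~holdout_mask], y[in_subset & ~holdout_mask])
    probs = clf.predict_proba(G[in_subset & holdout_mask])
    return log_loss(y[in_subset & holdout_mask], probs, labels=clf.classes_)

# Score a few random task subsets without any model fine-tuning runs.
rng = np.random.default_rng(0)
holdout = rng.random(len(y)) < 0.1
all_tasks = np.unique(task_ids)
for _ in range(3):
    subset = rng.choice(all_tasks, size=5, replace=False)
    print(sorted(subset.tolist()), estimate_subset_loss(subset, holdout))
```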
Meta training. This step fine-tunes a language model on the combination of all tasks.
- Use `custom_train_cot.py` to fine-tune LMs on the chain-of-thought data.
- Use `custom_train_instruction.py` to fine-tune LMs on the instruction fine-tuning data. Use `--train_instruction` to load the FLAN v2 dataset. Without `--train_instruction`, it will load the Alpaca dataset.

We provide bash script examples under `scripts/meta_training_**.sh`.
Evaluating and projecting gradients on all training samples:
- For chain-of-thought fine-tuning, use `fast_estimate_compute_gradients_cot.py`. Use `--load_model_dir` to specify a saved checkpoint directory as the base model. Specify `--project_dim` as the number of projections.
- For instruction fine-tuning, use `fast_estimate_eval_approximation_instruction.py`. Use `--train_instruction` to load the FLAN v2 or Alpaca datasets. Use `--compute_pretrained_outputs` to compute the gradients. The parameters are similar to the file above.

We provide bash script examples under `scripts/fast_estimate_gradients_**.sh`. These files will save the projection matrix and all projected gradients under a `./gradients/` folder. Please create the folder before usage.
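For intuition about what `--project_dim` controls, below is a small sketch of randomly projecting per-sample gradients to a low dimension. It is only illustrative; the projection used in this repository may differ in scaling and in how it handles very large models.

```python
# Illustrative random projection of per-sample gradients (not the repo's exact code).
import torch

def flatten_grads(model):
    """Concatenate the gradients of all parameters into a single vector."""
    return torch.cat([p.grad.reshape(-1) for p in model.parameters() if p.grad is not None])

def make_projection(num_params, project_dim, seed=0):
    """Fixed random Gaussian projection, scaled to roughly preserve inner products.
    A dense matrix like this is impractical for billion-parameter models; a chunked
    or sparse projection would be used there instead."""
    gen = torch.Generator().manual_seed(seed)
    return torch.randn(project_dim, num_params, generator=gen) / project_dim ** 0.5

def project_example_gradient(model, loss, projection):
    """Backpropagate one example's loss and return its projected gradient."""
    model.zero_grad()
    loss.backward()
    with torch.no_grad():
        return projection @ flatten_grads(model)
```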
We solve linear regression using the gradients collected above as features to estimate the output of the model fine-tuned on a subset of tasks.
- For chain-of-thought fine-tuning, use `fast_estimate_linear_model_cot.py`. Specify `--save_name` for the file to save the evaluation results of the estimated models. Specify `--number_of_subsets` and `--subset_size` to control the number and size of the sampled subsets.
- For instruction fine-tuning, use `fast_estimate_linear_regression_alpaca.py`. The parameters are similar to the above.

We provide bash script examples under `scripts/fast_estimate_logistic_regression_**.sh`. Inside the files, one can modify the subsets collection file under `./sampled_tasks/` to specify the sampled subsets of tasks. Normally, these should be randomly sampled subsets.
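If you need to regenerate a subsets collection, the following is a hypothetical sketch of sampling random task subsets and writing them to a text file. The file format under `./sampled_tasks/` is an assumption here; follow the format of the provided example files.

```python
# Hypothetical sketch: sample random task subsets and write one subset per line.
# The exact format expected under ./sampled_tasks/ is an assumption.
import random

task_names = [f"task_{i}" for i in range(20)]  # placeholder task list
number_of_subsets, subset_size = 100, 5

random.seed(0)
with open("./sampled_tasks/random_subsets.txt", "w") as f:
    for _ in range(number_of_subsets):
        subset = random.sample(task_names, subset_size)
        f.write(" ".join(subset) + "\n")
```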
- Forward stepwise selection: please refer to `utils/fast_estimate_forward_selection.py` to conduct forward selection of a subset of data.
- Random ensemble: please refer to `utils/select_random_ensemble.py` for an example of estimating random-ensemble scores. We then apply a threshold (equivalently, a top-k selection) to the scores to select a subset of tasks.
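As a rough illustration of the random-ensemble idea (not the exact logic of `utils/select_random_ensemble.py`), each task can be scored by averaging the estimated results of the random subsets that contain it, and the top-k scores give the selected subset:

```python
# Illustrative random-ensemble scoring; the real logic may differ in detail.
import numpy as np

def random_ensemble_scores(subsets, estimated_results, num_tasks):
    """subsets[i] is a list of task indices; estimated_results[i] is the estimated
    performance of the model fine-tuned on that subset (from the estimation step)."""
    scores = np.zeros(num_tasks)
    counts = np.zeros(num_tasks)
    for subset, result in zip(subsets, estimated_results):
        for t in subset:
            scores[t] += result
            counts[t] += 1
    return scores / np.maximum(counts, 1)  # average result of subsets containing each task

def select_top_k(scores, k):
    """Thresholding the scores, viewed as top-k selection."""
    return np.argsort(-scores)[:k]
```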
If you find this repository useful or happen to use it in a research paper, please cite our work with the following bib information:
```bibtex
@article{Li2024scalable,
  title={Scalable Fine-tuning from Multiple Data Sources: A First-Order Approximation Approach},
  author={Li, Dongyue and Zhang, Ziniu and Wang, Lu and Zhang, Hongyang R},
  journal={EMNLP Findings},
  year={2024},
}
```