This repository contains the code and datasets for the study "Discrete Prompt Compression with Reinforcement Learning." The project explores a practical way to compress prompts for instruction-tuned models using reinforcement learning, improving efficiency while preserving performance.
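As a rough illustration of the idea (a toy sketch, not the repository's actual policy), discrete prompt compression can be framed as a per-token keep/drop decision that a learned policy would make:

```python
# Toy illustration: compression as a per-token keep/drop decision.
# In PCRL-style training a learned policy scores each token; here a
# fixed mask stands in for those decisions.

def compress(tokens, keep_mask):
    """Drop every token whose mask entry is 0."""
    return [tok for tok, keep in zip(tokens, keep_mask) if keep]

prompt = ["Please", "kindly", "summarize", "the", "following", "text", ":"]
mask = [0, 0, 1, 0, 0, 1, 1]  # a trained policy would output these
print(" ".join(compress(prompt, mask)))  # summarize text :
```

The compressed prompt keeps the task-relevant tokens while dropping filler, which is the behavior the RL objective rewards.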
- Python 3.9
- PyTorch 1.13.1
- Other dependencies listed in `requirements.txt`
- Clone the repository:
  ```shell
  git clone https://github.com/nenomigami/PromptCompressor.git
  ```
- Install the required packages:
  ```shell
  cd PromptCompressor
  pip install -r requirements.txt
  ```
The dataset used in this study is Alpaca+, taken from the Mu et al. (2023) repository. We use this dataset directly for our prompt compression experiments. The Alpaca+ data is located in `data/alpaca_plus`.
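For reference, records in this dataset family typically follow the Alpaca JSON schema (`instruction`, `input`, `output`); the snippet below sketches reading such a file, though the exact file layout under `data/alpaca_plus` may differ:

```python
import json
import os
import tempfile

# Write a tiny Alpaca-style sample to a temp file (a stand-in for the
# real data under data/alpaca_plus, whose layout may differ).
sample = [{"instruction": "Summarize the text.",
           "input": "A long passage ...",
           "output": "A short summary."}]
path = os.path.join(tempfile.mkdtemp(), "sample.json")
with open(path, "w") as f:
    json.dump(sample, f)

# Load the records and build prompts: the instruction, plus the
# input field when it is non-empty.
with open(path) as f:
    records = json.load(f)
prompts = [r["instruction"] + ("\n" + r["input"] if r["input"] else "")
           for r in records]
print(prompts[0])
```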
The training process consists of two main steps:
First, we fine-tune the existing foundation models (gpt2-xl, flan-t5-xl) on the Alpaca+ dataset. If you already have a preferred instruction-tuned model for your own project, you can use it instead and skip this step.
For gpt2-xl:
```shell
python script/finetune_gpt2.py
```
For flan-t5-xl:
```shell
python script/finetune_flan-t5-xl.py
```
After fine-tuning the models, we train PCRL with the following command (replace `myseed` and `my_experiment` with your own values):
```shell
python train_pcrl.py --config_path configs/gpt2.yml --log_to_wandb --seed=myseed --experiment_name=my_experiment
```
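Conceptually, training of this kind trades off how much the prompt shrinks against how well the model's output is preserved. The toy function below is an illustrative stand-in for such a reward; the coefficients and the similarity measure are assumptions, not the repository's actual choices:

```python
# Simplified stand-in for a compression-vs-faithfulness reward.
# alpha/beta and the output-similarity measure are illustrative only.

def reward(orig_len, comp_len, output_sim, alpha=1.0, beta=1.0):
    """Higher when the prompt is shorter and the output stays faithful."""
    compression_ratio = 1.0 - comp_len / orig_len  # fraction of tokens removed
    return alpha * compression_ratio + beta * output_sim

# e.g. removing half the tokens while keeping 90% output similarity:
print(round(reward(orig_len=20, comp_len=10, output_sim=0.9), 2))  # 1.4
```

A policy that deletes everything scores high on compression but low on similarity, so the two terms keep each other in check.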
The evaluation process involves executing four different scripts in a specific sequence. The dependencies between these scripts are as follows:
- `evaluate_pcrl.py` requires the results from `evaluate_original.py`.
- `evaluate_chatgpt` requires the results from the other experiments.
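The required ordering can be encoded in a small driver script (hypothetical, not part of the repo). It only prints the sequence here; each command could instead be passed to `subprocess.run(cmd, check=True)` in turn:

```python
# Hypothetical driver encoding the dependency order: evaluate_original.py
# must run before evaluate_pcrl.py, and evaluate_chatgpt runs last.
steps = [
    ["python", "scripts/evaluate_original.py", "--gen_model=gpt2-xl",
     "--bs=16", "--results_dir=results"],
    ["python", "scripts/evaluate_heuristic.py", "--gen_model=gpt2-xl",
     "--bs=16", "--results_dir=results"],
    ["python", "scripts/evaluate_pcrl.py", "--pcrl_model=gpt2-xl",
     "--seed=myseed", "--gen_model=gpt2-xl", "--bs=16",
     "--results_dir=results"],
    ["python", "scripts/evaluate_chatgpt", "--gen_model=gpt2-xl",
     "--eval_model=gpt2-xl", "--split=seen", "--seed=myseed"],
]
for cmd in steps:
    print(" ".join(cmd))  # replace with subprocess.run(cmd, check=True)
```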
Run the `evaluate_original.py` script with the following arguments:
```shell
python scripts/evaluate_original.py --gen_model=gpt2-xl --bs=16 --results_dir=results
```
Run the `evaluate_heuristic.py` script with the same arguments as `evaluate_original.py`:
```shell
python scripts/evaluate_heuristic.py --gen_model=gpt2-xl --bs=16 --results_dir=results
```
After obtaining the results from `evaluate_original.py`, run the `evaluate_pcrl.py` script with the following arguments:
```shell
python scripts/evaluate_pcrl.py --pcrl_model=gpt2-xl --seed=myseed --gen_model=gpt2-xl --bs=16 --results_dir=results
```
Finally, run the `evaluate_chatgpt` script with the following arguments:
```shell
python scripts/evaluate_chatgpt --gen_model=gpt2-xl --eval_model=gpt2-xl --split=seen --seed=myseed
```
The codebase is licensed under Apache 2.0 (see LICENSE). The data is a mixture of Self-Instruct (Apache 2.0) and Stanford Alpaca (CC BY-NC 4.0). Models trained on this mixed data inherit both licenses.
This code builds on RL4LMs by AllenAI and the gisting code by Mu et al. We thank the authors for their contributions to the field and for making their code publicly available.
For any questions or feedback, please contact ghdbsl98@gmail.com.