Adaptive Contrastive Search: Uncertainty-Guided Decoding for Open-Ended Text Generation (ACS)
Official PyTorch Implementation
Paper | Run Adaptive Contrastive Search Demo
This repo contains PyTorch model definitions, sampling code, and evaluation code for our paper Adaptive Contrastive Search: Uncertainty-Guided Decoding for Open-Ended Text Generation (ACS). You can find more technical details in our paper.
Adaptive Contrastive Search: Uncertainty-Guided Decoding for Open-Ended Text Generation
Esteban Garces Arias, Julian Rodemann, Meimingwei Li, Christian Heumann, Matthias Aßenmacher
Department of Statistics, LMU Munich, Munich Center for Machine Learning (MCML)
In this study, we introduce adaptive contrastive search, a novel decoding strategy that extends contrastive search with an adaptive degeneration penalty, guided by the estimated uncertainty of the model at each generation step. This strategy is designed to enhance both the creativity and diversity of language modeling while producing coherent, high-quality text. Our findings indicate performance gains on both fronts across different model architectures and datasets, underscoring the effectiveness of our method in text generation tasks. Our code base, datasets, and models are publicly available.
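For intuition, the sketch below shows one way to couple the degeneration penalty to model uncertainty: the weight alpha is derived from the entropy of the top-k next-token distribution and plugged into the usual contrastive search objective. This is an illustrative simplification with our own helper names, not the exact formulation from the paper:

```python
import torch
import torch.nn.functional as F

def adaptive_alpha(logits: torch.Tensor, k: int) -> float:
    """Illustrative only: derive the degeneration penalty weight from the
    normalized entropy of the top-k next-token distribution.
    High uncertainty -> larger alpha -> stronger penalty."""
    top_logits = logits.topk(k).values
    probs = F.softmax(top_logits, dim=-1)
    entropy = -(probs * probs.log()).sum()
    return (entropy / torch.log(torch.tensor(float(k)))).item()

def contrastive_score(p_token: float, max_sim: float, alpha: float) -> float:
    """Standard contrastive search objective for one candidate token:
    model confidence minus similarity to the previous context."""
    return (1.0 - alpha) * p_token - alpha * max_sim
```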
This repository contains:
- 🪐 A simple PyTorch implementation of Adaptive Contrastive Search
- ⚡️ Faster Metrics Calculation with Coherence, MAUVE and Diversity
- 💥 A Colab notebook for running the Adaptive Contrastive Search demo
- Environment Setup
- Dataset Information
- Generation Baselines
- Metrics Evaluation (MAUVE, Diversity, Coherence)
- Human Evaluation
Environment Setup 🚀 [Back to Top]
First, download and set up the repo:
git clone https://github.com/YecanLee/Adaptive-Contrastive-Search
cd Adaptive-Contrastive-Search
We provide an environment.yml file that can be used to create a Conda environment.
conda env create -f environment.yml
conda activate acs
If you prefer not to use Conda, you can install the dependencies using pip.
pip install -r requirements.txt
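Either way, a quick sanity check (generic PyTorch, nothing repo-specific) confirms the install sees your GPU:

```python
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```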
If you want to use flash-attention, you may need to install the cuda-toolkit yourself if it is not already installed.
conda install nvidia/label/cuda-12.2.0::cuda-toolkit
pip install ninja packaging flash-attn
You can find the specific version of cuda-toolkit that is compatible with your GPU from here.
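For reference, Hugging Face transformers models can usually opt into flash attention at load time. This is a generic example, not a snippet from the repo's scripts:

```python
import torch
from transformers import AutoModelForCausalLM

# Generic example: load a causal LM with FlashAttention-2 enabled.
# Requires the flash-attn package and half-precision weights on GPU.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2-xl",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```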
Dataset Information [Back to Top]
We used two different datasets for the performance comparison in the paper. The first is the English dataset in the data folder, which contains wikitext, wikinews, and book files. The second is the multilingual dataset, which contains files for 15 different languages.
For details on each dataset, please see data/README.md and multilingual_data/README.md.
Generation Baselines [Back to Top]
**The following is the static contrastive search baseline from our paper, used for performance comparison:**
You can run the story generation baseline with the static contrastive search method by running static_contrastive_base.py. The script has various arguments to switch between datasets, adjust the penalty alpha and k values, change the path where generation results are saved, etc. For example, to sample from the wikitext dataset with k=10 and alpha=0.6, run:
python story_generation/static_contrastive_search.py \
--config_path configs/contrastive_base.yaml \
--model_name gpt2-xl \
--save_path_prefix wikitext_CS \
--k 10 \
--alpha 0.6 \
--dataset wikitext
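If you just want to try the static baseline quickly, plain Hugging Face transformers already implements contrastive search; penalty_alpha and top_k play the roles of alpha and k above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
model = AutoModelForCausalLM.from_pretrained("gpt2-xl")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 triggers contrastive search.
output = model.generate(**inputs, penalty_alpha=0.6, top_k=10, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```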
**The following is our proposed adaptive contrastive search method from the paper:**
You can run the story generation baseline with our proposed adaptive contrastive search method by running adaptive_contrastive_base.py. This script also has various arguments to switch between datasets, set the starting values of the penalty alpha and k, change the path where generation results are saved, etc. For example, to sample from the wikitext dataset with q=8 and an initial k=10, run:
python story_generation/adaptive_contrastive_search.py \
--config_path configs/adaptive_contrastive_base.yaml \
--q 8 \
--dataset wikitext \
--save_path_prefix wikitext_ACS \
--k 10
To reproduce the results in the paper, change the --q flag for the different experiment settings, and set --dataset to one of wikitext, wikinews, or book.
You can also change the --k flag if you want the ACS strategy to be initialized from a different k.
If you only want to test generation on a few samples, set the --data_num flag to your preferred value. By default, the whole dataset is used for generation.
Generation could likely be sped up by:
- using Flash Attention in the generation scripts; please check the Environment Setup section for more details.
- using torch.compile from PyTorch 2.0; we implemented this with the max-autotune mode in the generation scripts, and you may need to modify the torch.compile calls to fit your needs (a minimal sketch follows below).
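As a minimal sketch of the second point (assuming a model variable already loaded):

```python
import torch

# Compile with the most aggressive autotuning preset; the first forward
# pass is slow while kernels are tuned, subsequent calls run faster.
model = torch.compile(model, mode="max-autotune")
```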
TF32 Note (important for Ampere, Hopper, and other recent NVIDIA GPUs users).
When we ran the above generation scripts, TF32 matmuls were disabled per PyTorch's defaults. We have enabled them at the top of static_contrastive_base.py and adaptive_contrastive_base.py because they make sampling significantly faster on those GPUs. Note that using TF32 may lead to small numerical differences in the results, though these are likely negligible for most comparison purposes.
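The switches in question are one-liners; leaving them at PyTorch's defaults restores bit-exact FP32 matmuls:

```python
import torch

# Allow TF32 on matmuls and cuDNN ops (effective on Ampere+ GPUs).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```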
Metrics Evaluation (MAUVE, Diversity, Coherence) [Back to Top]
We provide scripts for calculating the MAUVE, diversity, and coherence scores. The MAUVE score is computed with the mauve-text package, the diversity score uses the n-gram diversity method described in the contrastive search paper, and the coherence score follows the method from the coherence paper.
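For reference, here is a minimal sketch of the n-gram diversity computation (our paraphrase, with whitespace tokenization for brevity) and of calling the mauve-text package:

```python
import mauve  # pip install mauve-text

def ngram_diversity(text: str, ns=(2, 3, 4)) -> float:
    """Product over n of (unique n-grams / total n-grams); a paraphrase
    of the diversity metric from the contrastive search paper."""
    tokens = text.split()  # whitespace tokenization for illustration only
    score = 1.0
    for n in ns:
        ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        if ngrams:
            score *= len(set(ngrams)) / len(ngrams)
    return score

# Placeholder lists of strings; use real references and generations.
human_texts = ["reference text one", "reference text two"]
generated_texts = ["generated text one", "generated text two"]
out = mauve.compute_mauve(p_text=human_texts, q_text=generated_texts)
print(out.mauve)
```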
To calculate the MAUVE, diversity, and coherence scores, please run the following commands:
cd scripts/
# Compute the coherence score
measure_coherence.sh YOUR_GENERATION_RESULT_PATH
# Compute the MAUVE and diversity
measure_mauve.sh YOUR_GENERATION_RESULT_PATH
TF32 Note (important for Ampere, Hopper, and other recent NVIDIA GPUs users).
Our coherence score calculation script uses TF32 for faster computation, which may lead to very small differences in the results.
Human Evaluation Result [Back to Top]
To check the human evaluation results reported in the paper, please see the Human_Evaluation_Results.xlsx file.
@article{garces2024adaptive,
title={Adaptive Contrastive Search: Uncertainty-Guided Decoding for Open-Ended Text Generation},
author={Garces Arias, Esteban and Rodemann, Julian and Li, Meimingwei and Heumann, Christian and A{\ss}enmacher, Matthias},
journal={arXiv e-prints},
pages={arXiv--2407},
year={2024}
}
We wish to express our gratitude to Daniele Pugno and Nicolò Campagnoli for their technical support and visualizations. Matthias Aßenmacher received funding from the Deutsche Forschungsgemeinschaft; Julian Rodemann acknowledges support from the Federal Statistical Office of Germany as well as from the Bavarian Institute for Digital Transformation (bidt) and the Bavarian Academy of Sciences (BAS).
The codebase borrows from the Contrastive_Search_versus_Contrastive_Decoding repository. We thank the authors for open-sourcing their code.
See LICENSE.txt for details.