We developed this codebase to study few-shot learning on text classification tasks. More specifically, we seek to understand how language models process prompts containing false demonstrations. Our code makes it possible to:
- Use the logit lens and calibration to decode the model’s intermediate computations
- Compute statistics on attention heads (where they read from and write to)
- Ablate attention heads to analyze how such interventions affect model behavior
We've made it easy to scale the codebase to other datasets, models, and settings.
See our paper for more details.
The code requires a single GPU (see Colab for free GPU access). There is no training, and inference is relatively cheap. However, a GPU with large memory is needed to run the larger models; for example, we used an A100 40GB for GPT-J and an A100 80GB for GPT-NeoX-20B. If you are limited by memory, consider loading the model in fp16.
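For instance, with Hugging Face transformers you can pass torch_dtype at load time; a minimal sketch, with GPT-J shown purely as an example model:

```python
# Load a large model in half precision to roughly halve GPU memory use.
# The model name below is just an example.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", torch_dtype=torch.float16
).to("cuda")
```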
We recommend creating a new anaconda environment to simplify package management:
conda create -n overthinking_the_truth python=3.11
conda activate overthinking_the_truth
pip install -r requirements.txt
Alternatively, if you prefer not to use conda, simply run:
pip install -r requirements.txt
See notebooks/demo for a walkthrough of our pipeline.
Run the following inside scripts/logit_lens with a choice of {model(s), dataset(s), setting(s)}:
python logit_lens.py --models {models} --datasets {datasets} --settings {settings} --num_inputs {#inputs} --num_demos {#demos}
Concretely, the following command runs gpt_j on the datasets sst2 and dbpedia in the permuted_incorrect_labels setting, with 1000 inputs and 40 demonstrations per input:
python logit_lens.py --models gpt_j --datasets sst2,dbpedia --settings permuted_incorrect_labels --num_inputs 1000 --num_demos 40
To replicate all our logit_lens results, run the following command.
python logit_lens.py --models all --datasets all --settings all --num_inputs 1000 --num_demos "max"
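For intuition about what logit_lens.py measures, here is a minimal sketch of the logit-lens technique, written against gpt2 with a toy prompt for brevity. This is an illustrative reimplementation of the core idea, not our exact pipeline (which adds calibration, batching, and the settings above):

```python
# Logit lens: decode each layer's hidden state at the final position
# through the final LayerNorm and the unembedding matrix to see what
# the model "predicts" at intermediate depths.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

prompt = "Review: a delightful film. Sentiment: positive\nReview: a dull mess. Sentiment:"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# out.hidden_states has n_layer + 1 entries of shape [batch, seq, d_model];
# the last entry already has the final LayerNorm (ln_f) applied.
for layer, h in enumerate(out.hidden_states[:-1]):
    logits = model.lm_head(model.transformer.ln_f(h[0, -1]))
    print(f"layer {layer:2d}: {tokenizer.decode(logits.argmax().item())!r}")

final_logits = model.lm_head(out.hidden_states[-1][0, -1])
print(f"final:    {tokenizer.decode(final_logits.argmax().item())!r}")
```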
Run the following inside scripts/attention with a choice of {model(s), dataset(s)}:
python attention.py --models {models} --datasets {datasets} --num_inputs {#inputs} --num_demos {#demos}
To replicate all our attention results, run the following command.
python attention.py --models all --datasets unnatural --num_inputs 250 --num_demos "max"
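If you want a feel for the raw quantities behind these statistics, the sketch below pulls per-head attention patterns from gpt2 on a toy prompt. It is illustrative only; scripts/attention computes more refined read/write statistics:

```python
# Inspect where each attention head reads from at the final position.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

inputs = tokenizer("Review: a dull mess. Sentiment:", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True)

# out.attentions has n_layer entries of shape [batch, n_head, seq, seq].
for layer, attn in enumerate(out.attentions):
    last_row = attn[0, :, -1, :]        # each head's attention from the final position
    top_src = last_row.argmax(dim=-1)   # position each head attends to most
    print(f"layer {layer:2d}: top source position per head = {top_src.tolist()}")
```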
Run the following inside scripts/ablations with a choice of {model(s), dataset(s), setting(s), ablation type(s)}:
python ablations.py --models {models} --datasets {datasets} --settings {settings} --abl_types {ablation_types} --num_inputs {#inputs} --num_demos {#demos}
To replicate all our ablation results, run the following (the & launches the two commands in parallel):
python ablations.py --models gpt_j --datasets all --settings all --abl_types all --num_inputs 1000 --num_demos max &
python ablations.py --models gpt2_xl,gpt_neox --datasets all --settings all --abl_types attention,mlp --num_inputs 1000 --num_demos max
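For intuition about what a head ablation does mechanically, the sketch below zero-ablates one gpt2 attention head using a forward pre-hook on the attention output projection. This is an illustrative implementation of zero ablation, not necessarily how ablations.py performs it, and the layer/head indices are arbitrary examples:

```python
# Zero-ablate one attention head so its contribution never reaches the
# residual stream, then run the model as usual.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

layer, head = 9, 3                      # example head to ablate
head_dim = model.config.n_embd // model.config.n_head

def ablate_head(module, args):
    # Input to c_proj is the concatenation of all head outputs:
    # [batch, seq, n_head * head_dim]. Since c_proj is linear, zeroing
    # one head's slice removes exactly that head's contribution.
    hidden = args[0].clone()
    hidden[..., head * head_dim:(head + 1) * head_dim] = 0.0
    return (hidden,) + args[1:]

handle = model.transformer.h[layer].attn.c_proj.register_forward_pre_hook(ablate_head)

inputs = tokenizer("Review: a dull mess. Sentiment:", return_tensors="pt")
with torch.no_grad():
    ablated_logits = model(**inputs).logits

handle.remove()                          # restore the unmodified model
```

Removing the hook restores the original model, so clean and ablated logits can be compared side by side.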
See notebooks/analysis to analyze the output of the scripts. For example, after running the logit_lens script, the notebooks/analysis/logit_lens notebook will generate the layerwise plots.
If you find our work useful, please cite our paper:

@misc{halawi2023overthinking,
title={Overthinking the Truth: Understanding how Language Models Process False Demonstrations},
author={Danny Halawi and Jean-Stanislas Denain and Jacob Steinhardt},
year={2023},
eprint={2307.09476},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
This code was developed by Danny Halawi and Jean-Stanislas Denain. Questions about the codebase can be directed to us at {dhalawi, js_denain}@berkeley.edu, respectively. If you find a bug, please open an issue. If you'd like to contribute, feel free to open a pull request.