Code for training and evaluating a sparse autoencoder on OthelloGPT.
Contents:
- analysis.py: many methods for evaluating the sparse autoencoders (see below for details)
- autoencoder.py: architecture for a sparse autoencoder, as in Cunningham et al.
- linear_probes.py: architecture for a linear probe on the residual stream of a language model, for classifying board states
- model_training.py: control code for training OthelloGPT, the probes, and the autoencoder (see below for details)
- othello_gpt.py: architecture for an OthelloGPT model, as in Li et al.
- requirements.txt: Python dependencies. Install with `pip install -r requirements.txt`
- train.py: code that executes a training run on OthelloGPT, the probes, or the autoencoder
- utils/game_engine.py: plays Othello
- utils/generate_training_corpus.py: generates a corpus of Othello games, for training or assessment (see below for details)
- Some other auxiliary methods in utils/
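For orientation before reading autoencoder.py, here is a sketch of the tied-weight sparse autoencoder formulation described by Cunningham et al.: feature activations c = ReLU(Mx + b), reconstruction x_hat = M^T c, and loss ||x - x_hat||^2 + alpha * ||c||_1, with unit-norm dictionary rows. It is written in plain Python with hypothetical function names (`normalize_rows`, `encode`, `decode`, `sae_loss`) and is an assumption about the general technique, not a copy of autoencoder.py, which may differ (for example, by using an untied decoder or a decoder bias).

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # dictionary M: one row (feature direction) per feature

# Hypothetical helpers illustrating the technique; not the API of autoencoder.py.


def normalize_rows(m: Matrix) -> Matrix:
    """Rescale each dictionary row to unit L2 norm (all-zero rows are left as-is)."""
    out = []
    for row in m:
        norm = math.sqrt(sum(v * v for v in row)) or 1.0
        out.append([v / norm for v in row])
    return out


def encode(x: Vector, m: Matrix, b: Vector) -> Vector:
    """Feature activations c = ReLU(M x + b)."""
    return [
        max(0.0, sum(w * xi for w, xi in zip(row, x)) + bi)
        for row, bi in zip(m, b)
    ]


def decode(c: Vector, m: Matrix) -> Vector:
    """Reconstruction x_hat = M^T c, reusing the encoder weights (tied)."""
    x_hat = [0.0] * len(m[0])
    for ci, row in zip(c, m):
        if ci:  # inactive features (ReLU output 0) contribute nothing
            for j, w in enumerate(row):
                x_hat[j] += ci * w
    return x_hat


def sae_loss(x: Vector, m: Matrix, b: Vector, l1_coeff: float) -> float:
    """Squared reconstruction error plus an L1 sparsity penalty on the activations."""
    c = encode(x, m, b)
    x_hat = decode(c, m)
    reconstruction = sum((xi - xh) ** 2 for xi, xh in zip(x, x_hat))
    sparsity = sum(abs(ci) for ci in c)
    return reconstruction + l1_coeff * sparsity
```

With the 2 x 2 identity dictionary and zero bias, `encode([3.0, -2.0], ...)` gives `[3.0, 0.0]`, and `sae_loss` with `l1_coeff=0.5` gives 4.0 + 0.5 * 3.0 = 5.5. In practice the dictionary is trained by gradient descent and has many more rows than the residual stream has dimensions; the feature numbers 0-1023 used below suggest 1024 features.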
How to use this code:
- Download datasets and trained models from https://drive.google.com/drive/folders/1xMkEctaqAUjoPXGY-9dBu-pE3SJjKx2K
- If you did not download the trained sparse autoencoder, or want to train your own, run model_training.py's `full_sae_training()` method. The OthelloGPT model must have an `intermediate_residual_stream()` method (see othello_gpt.py for details).
- If you did not download the cached AUROCs, or want to compute new ones, run analysis.py's `evaluate_all_probe_classification()`, `evaluate_all_legal_moves_classification()`, and `evaluate_all_content_classification()`.
- To print the best AUROCs, run analysis.py's `find_top_aurocs_legal()` or `find_top_aurocs_contents()`.
- To create density plots that evaluate a feature as a classifier for a board position, run analysis.py's `create_density_plot_contents(feature_number=N, board_position=M)` or `create_density_plot_legal(feature_number=N, board_position=M)`. Feature numbers N run from 0 to 1023, and board positions M run from 0 to 63. The values of N and M correspond to those reported by `find_top_aurocs_legal` or `find_top_aurocs_contents`.
- To show the top- and random-activating board states for a feature, run analysis.py's `show_top_activating(N, marked_position=M)`. N is again the feature number, and M is the board position that will be marked with a red circle.
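The AUROC steps above treat a feature's activation as a score for a binary property of one board square (for example, whether that square is a legal move). A minimal sketch of the idea, assuming activations and labels are plain nested lists: `auroc` computes the probability that a randomly chosen positive sample outscores a randomly chosen negative one (ties count one half), and `top_aurocs` ranks every (feature, position) pair. Both names and their signatures are hypothetical and are not the functions in analysis.py.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical helpers illustrating the technique; not the API of analysis.py.


def auroc(scores: Sequence[float], labels: Sequence[bool]) -> float:
    """AUROC as P(score of a positive > score of a negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y]
    neg = [s for s, y in zip(scores, labels) if not y]
    if not pos or not neg:
        return float("nan")  # undefined when only one class is present
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def top_aurocs(
    activations: Sequence[Sequence[float]],
    labels: Sequence[Sequence[bool]],
    k: int = 5,
) -> List[Tuple[float, int, int]]:
    """Score every (feature, position) pair and return the k highest AUROCs.

    activations[i][f]: activation of feature f on sample i
    labels[i][m]: True if board position m has the property on sample i
    Returns (auroc, feature_number, board_position) tuples, best first.
    """
    results = []
    for f in range(len(activations[0])):
        feature_scores = [row[f] for row in activations]
        for m in range(len(labels[0])):
            a = auroc(feature_scores, [row[m] for row in labels])
            if not math.isnan(a):  # skip positions where one class is absent
                results.append((a, f, m))
    results.sort(reverse=True)
    return results[:k]
```

For example, `auroc([0.1, 0.4, 0.35, 0.8], [False, False, True, True])` returns 0.75. The pairwise loop costs O(P * N) comparisons for P positive and N negative samples; for large corpora, the equivalent rank-based (Mann-Whitney U) formula or scikit-learn's `roc_auc_score` is much faster.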