Commit
Merge branch 'main' into refactor-train-sae
jbloomAus authored Mar 22, 2024
2 parents 01978e6 + bcb9a52 commit 0acdcb3
Showing 23 changed files with 274 additions and 2,578 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -1,5 +1,5 @@
[flake8]
extend-ignore = E203, E266, E501, W503, E721, F722, E731
extend-ignore = E203, E266, E501, W503, E721, F722, E731, E402
max-line-length = 79
max-complexity = 25
extend-select = E9, F63, F7, F82
84 changes: 61 additions & 23 deletions README.md
@@ -1,37 +1,25 @@
<img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">

# MATS SAE Training

[![build](https://github.com/jbloomAus/mats_sae_training/actions/workflows/tests.yml/badge.svg)](https://github.com/jbloomAus/mats_sae_training/actions/workflows/tests.yml)

This codebase contains training scripts and analysis code for Sparse AutoEncoders. I wasn't planning to share this codebase initially, but I've received feedback that others have found it useful, so I'm going to slowly transition it to be a more serious repo (formatting/linting/testing etc.). In the meantime, please feel free to add pull requests or open issues if you have any trouble with it.
The MATS SAE training codebase (we'll rename it soon) exists to help researchers:
- Train sparse autoencoders.
- Analyse sparse autoencoders and neural network internals.
- Generate insights which make it easier to create safe and aligned AI systems.

## Quick Start

## Set Up
### Set Up

This project uses [Poetry](https://python-poetry.org/) for dependency management. Ensure Poetry is installed, then install the dependencies by running:

```
poetry install
```

## Background

We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab).

## Code Overview

The codebase contains 2 folders worth caring about:

- sae_training: The main body of the code is here. Everything required for training SAEs.
- sae_analysis: This code mainly houses the feature visualizer code we use to generate dashboards. It was written by Callum McDougall, but I've ported it here with permission and edited it to work with a few different activation types.

Some other folders:

- tutorials: These aren't well maintained but I'll aim to clean them up soon.
- tests: When first developing the codebase, I was writing more tests. I have no idea whether they are currently working!

I've been committing my research code to the `Research` folder but am not expecting other people to use or look at it.

## Loading Sparse Autoencoders from Huggingface
### Loading Sparse Autoencoders from Huggingface

[Previously trained sparse autoencoders](https://huggingface.co/jbloom/GPT2-Small-SAEs) can be loaded from Hugging Face with close to a single line of code. For more details and performance metrics for these sparse autoencoders, read my [blog post](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream).

@@ -58,9 +46,58 @@ path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
log_feature_sparsity = torch.load(path, map_location=sparse_autoencoder.cfg.device)

```
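For orientation, the snippet above is part of a slightly longer flow. Here is a hedged sketch of the whole thing; `FILENAME` is a placeholder (check the Hugging Face repo's file listing for the real names), while `LMSparseAutoencoderSessionloader` is the same helper used elsewhere in this codebase.

```
# Hedged sketch: load a pretrained GPT-2 Small SAE from Hugging Face.
# FILENAME is a placeholder -- the real .pt filenames are listed in the HF repo.
from huggingface_hub import hf_hub_download

from sae_training.utils import LMSparseAutoencoderSessionloader

REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = "example_sparse_autoencoder_gpt2-small_blocks.2.hook_resid_pre.pt"  # placeholder

path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

# The session loader returns the TransformerLens model, a group of SAEs, and an
# activation store; take the first autoencoder in the group.
model, sae_group, activation_store = (
    LMSparseAutoencoderSessionloader.load_session_from_pretrained(path)
)
sparse_autoencoder = sae_group.autoencoders[0]

# The per-feature log sparsity tensor ships as a separate file and is loaded
# with torch.load, as shown above.
```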
### Background

We highly recommend this [tutorial](https://www.lesswrong.com/posts/LnHowHgmrMbWtpkxx/intro-to-superposition-and-sparse-autoencoders-colab).


## High Level

### Motivation

- **Accelerate SAE Research**: Support fast experimentation to understand SAEs and improve SAE training so we can train SAEs on larger and more diverse models.
- **Make Research like Play**: Support research into language model internals via SAEs. Good tooling can make research tremendously exciting and enjoyable. Balancing modifiability and reliability with ease of understanding / access is the name of the game here.
- **Build an awesome community**: Mechanistic Interpretability already has an awesome community but as that community grows, it makes sense that there will be niches. I'd love to build a great community around Sparse Autoencoders.

### Goals

#### **SAE Training**: SAE training features will fit into a number of categories, including:
- **Making it easy to train SAEs**: Training SAEs is hard for a number of reasons, so making it easy for people to train SAEs with relatively little expertise seems like the main way this codebase will create value.
- **Training SAEs on more models**: Supporting training of SAEs on more models and architectures, and on different activations within those models.
- **Being better at training SAEs**: Enabling methodological changes which may improve SAE performance as measured by reconstruction loss, cross-entropy loss when using the reconstructed activations, L1 loss, L0, and interpretability of features, as well as improving training speed or reducing the compute required to train SAEs.
- **Being better at measuring SAE performance**: How do we know when SAEs are doing what we want them to? Improving training metrics should allow better decisions about which methods to use and which hyperparameter choices to make.
- **Training SAE variants**: People are already training “Transcoders”, which map from one activation to another (such as before/after an MLP layer). These can easily be supported with a few changes (a rough sketch follows this list). Other variants will come in time.
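
As an illustration of the transcoder idea only, and not of this codebase's API, a transcoder is just an SAE whose decoder targets a different activation from the one its encoder reads:

```
# Illustrative sketch of a transcoder-style SAE variant: encode one activation
# (e.g. pre-MLP) and reconstruct a different one (e.g. post-MLP). Class and
# field names here are illustrative, not the layout used in sae_training.
import torch
import torch.nn as nn


class Transcoder(nn.Module):
    def __init__(self, d_in: int, d_sae: int, d_out: int):
        super().__init__()
        self.W_enc = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(d_in, d_sae)))
        self.W_dec = nn.Parameter(nn.init.kaiming_uniform_(torch.empty(d_sae, d_out)))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_out))

    def forward(self, acts_in: torch.Tensor, acts_out: torch.Tensor):
        # Sparse feature activations are computed from the *input* activation...
        feature_acts = torch.relu(acts_in @ self.W_enc + self.b_enc)
        # ...but the reconstruction targets the *output* activation.
        recon = feature_acts @ self.W_dec + self.b_dec
        mse_loss = (recon - acts_out).pow(2).mean()
        l1_loss = feature_acts.abs().sum(dim=-1).mean()
        return recon, mse_loss, l1_loss
```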

#### **Analysis with SAEs**: Using SAEs to understand neural network internals is an exciting but complicated task.
- **Feature-wise Interpretability**: This looks something like "for each feature, have as much knowledge about it as possible". Part of this will involve feature dashboard improvements or supporting better integrations with Neuronpedia.
- **Mechanistic Interpretability**: This comprises the more traditional kinds of mechanistic interpretability which TransformerLens supports and which this codebase should support too. Making it easy to patch, ablate, or otherwise intervene on features so as to find circuits will likely speed up lots of researchers (a rough sketch follows this list).
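
To make the "intervene on features" point concrete, here is a rough sketch of ablating a single SAE feature with a TransformerLens forward hook. The `encode`/`decode` calls are stand-ins for whatever interface the SAE class exposes, not a documented API; only `run_with_hooks` is standard TransformerLens.

```
# Rough sketch: zero out one SAE feature and splice the reconstruction back into
# the model's forward pass via a TransformerLens hook. `encode`/`decode` are
# assumed stand-ins for the SAE interface.
import torch


@torch.no_grad()
def run_with_feature_ablated(model, tokens, sae, feature_id: int):
    def ablate_feature(activations, hook):
        feature_acts = sae.encode(activations)  # assumed method
        feature_acts[..., feature_id] = 0.0
        # Returning a tensor from the hook replaces the activation downstream.
        return sae.decode(feature_acts)         # assumed method

    return model.run_with_hooks(
        tokens,
        fwd_hooks=[(sae.cfg.hook_point, ablate_feature)],
    )
```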

### Other Stuff

I think there are lots of other types of analysis that could be done in the future with SAE features. I've already explored many different types of statistical tests which can reveal interesting properties of features. There are also things like saliency mapping and attribution techniques which it would be nice to support.
- Accessibility and Code Quality: The codebase won't be used if it doesn't work, and it also won't get used if it's too hard to understand, modify or read.
  - Making the code accessible: This involves tasks like turning the codebase into a Python package.
  - Knowing how the code is supposed to work: Is the code well documented? This will require docstrings, tutorials and links to related work and publications. Getting aligned on what the code does is critical to sharing a resource like this.
  - Knowing the code works as intended: All code should be tested. Unit tests and acceptance tests are both important.
  - Knowing the code is actually performant: Tests help ensure the code works as intended, but deep learning introduces lots of complexity, which makes actually running benchmarks essential to having confidence in the code.


## Code Overview

The codebase contains 2 folders worth caring about:

- sae_training: The main body of the code is here. Everything required for training SAEs.
- sae_analysis: This code mainly houses the feature visualizer code we use to generate dashboards. It was written by Callum McDougall, but I've ported it here with permission and edited it to work with a few different activation types.

Some other folders:

- tutorials: These aren't well maintained but I'll aim to clean them up soon.
- tests: When first developing the codebase, I was writing more tests. I have no idea whether they are currently working!

I've been committing my research code to the `Research` folder but am not expecting other people to use or look at it.


## Training a Sparse Autoencoder on a Language Model
### Training your own Sparse Autoencoder

Sparse autoencoders can be intimidating at first, but it's fairly simple to train one once you know what each part of the config does. I've created a config class which you instantiate and pass to the runner, which will complete your training run and log its progress to wandb. A rough sketch of the flow is shown below; the full example appears further down this README.
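
As a rough sketch only (the module paths and config values below are assumptions and illustrative numbers, not the canonical example, which is elided from this diff):

```
# Rough sketch of the intended flow -- module paths and config fields below are
# assumptions / illustrative values, not the canonical README example.
import torch

from sae_training.config import LanguageModelSAERunnerConfig  # assumed module path
from sae_training.lm_runner import language_model_sae_runner  # assumed module path

cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",
    hook_point="blocks.2.hook_resid_pre",
    hook_point_layer=2,
    d_in=768,
    expansion_factor=32,    # illustrative
    lr=4e-4,                # illustrative
    l1_coefficient=8e-5,    # illustrative
    train_batch_size=4096,  # illustrative
    context_size=128,       # illustrative
    device="cuda" if torch.cuda.is_available() else "cpu",
)

sparse_autoencoder = language_model_sae_runner(cfg)
```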

@@ -158,7 +195,8 @@ model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader
## Tutorials

I wrote a tutorial to show users how to do some basic exploration of their SAE.
- `evaluating_your_sae.ipynb`: A quick and dirty notebook showing how to check L0 and prediction loss with your SAE, as well as how to generate interactive dashboards using Callum's reproduction of [Anthropic's interface](https://transformer-circuits.pub/2023/monosemantic-features#setup-interface).
- `logits_lens_with_features.ipynb`: A notebook showing how to reproduce the analysis from this [LessWrong post](https://www.lesswrong.com/posts/qykrYY6rXXM7EEs8Q/understanding-sae-features-with-the-logit-lens).
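
The core trick in the logit lens notebook is to project each feature's decoder direction through the model's unembedding; a hedged sketch (ignoring the final LayerNorm for simplicity):

```
# Hedged sketch of the logit lens on SAE features: project a feature's decoder
# direction through the unembedding to see which output tokens it promotes most.
# Ignores the final LayerNorm for simplicity.
import torch


@torch.no_grad()
def feature_logit_lens(sparse_autoencoder, model, feature_id: int, k: int = 10):
    direction = sparse_autoencoder.W_dec[feature_id].to(model.W_U.device)  # [d_model]
    logit_attribution = direction @ model.W_U                              # [d_vocab]
    top = torch.topk(logit_attribution, k)
    return [
        (model.tokenizer.decode([idx]), val)
        for idx, val in zip(top.indices.tolist(), top.values.tolist())
    ]
```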

## Example Dashboard

1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@ eindex = {git = "https://github.com/callummcdougall/eindex.git"}
datasets = "^2.17.1"
babe = "^0.0.7"
nltk = "^3.8.1"
sae-vis = {git = "https://github.com/callummcdougall/sae_vis.git"}


[tool.poetry.group.dev.dependencies]
92 changes: 41 additions & 51 deletions sae_analysis/dashboard_runner.py
@@ -1,12 +1,5 @@
# flake8: noqa: E402
# TODO: are these sys.path.append calls really necessary?

import sys
from typing import Any, cast

sys.path.append("..")
sys.path.append("../..")
import os
from typing import Any, Optional, cast

# set TOKENIZERS_PARALLELISM to false to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -18,20 +11,22 @@
import plotly
import plotly.express as px
import torch
from sae_vis.data_fetching_fns import get_feature_data
from sae_vis.data_storing_fns import FeatureVisParams
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

import wandb
from sae_analysis.visualizer.data_fns import get_feature_data
from sae_training.utils import LMSparseAutoencoderSessionloader


class DashboardRunner:

def __init__(
self,
sae_path: str | None = None,
sae_path: Optional[str] = None,
dashboard_parent_folder: str = "./feature_dashboards",
wandb_artifact_path: str | None = None,
wandb_artifact_path: Optional[str] = None,
init_session: bool = True,
# token pars
n_batches_to_sample_from: int = 2**12,
@@ -43,29 +38,9 @@ def __init__(
# util pars
use_wandb: bool = False,
continue_existing_dashboard: bool = True,
final_index: int | None = None,
final_index: Optional[int] = None,
):
"""
# # test it
# runner = DashboardRunner(
# sae_path = None,
# dashboard_parent_folder = "../feature_dashboards",
# wandb_artifact_path = "jbloom/mats_sae_training_gpt2_small_resid_pre_5/sparse_autoencoder_gpt2-small_blocks.2.hook_resid_pre_24576:v19",
# init_session = True,
# n_batches_to_sample_from = 2**12,
# n_prompts_to_select = 4096*6,
# n_features_at_a_time = 128,
# max_batch_size = 256,
# buffer_tokens = 8,
# use_wandb = True,
# continue_existing_dashboard = True,
# )
# runner.run()
"""
""" """

if wandb_artifact_path is not None:
artifact_dir = f"artifacts/{wandb_artifact_path.split('/')[2]}"
@@ -103,6 +78,7 @@ def __init__(
else:
assert sae_path is not None
self.sae_path = sae_path
self.feature_sparsity = None

if init_session:
self.init_sae_session()
@@ -152,6 +128,7 @@ def init_sae_session(self):
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
# TODO: handle multiple autoencoders
self.sparse_autoencoder = sae_group.autoencoders[0]

def get_tokens(
@@ -176,15 +153,16 @@ def get_tokens(

def get_index_to_resume_from(self):
i = 0
assert self.n_features is not None # keep pyright happy
for i in range(self.n_features):
if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"):
break

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
assert self.final_index is not None # keep pyright happy
n_features = self.sparse_autoencoder.cfg.d_sae
n_features_at_a_time = self.n_features_at_a_time
id_of_last_feature_without_dashboard = i
assert self.final_index is not None # keep pyright happy
n_features_remaining = self.final_index - id_of_last_feature_without_dashboard
n_batches_to_do = n_features_remaining // n_features_at_a_time
if self.final_index == n_features:
@@ -206,7 +184,11 @@
@torch.no_grad()
def get_feature_property_df(self):
sparse_autoencoder = self.sparse_autoencoder
feature_sparsity = self.feature_sparsity
feature_sparsity = (
self.feature_sparsity
if self.feature_sparsity is not None
else torch.tensor(0)
)

W_dec_normalized = (
sparse_autoencoder.W_dec.cpu()
@@ -258,11 +240,11 @@ def run(self):
self.init_sae_session()

# generate all the plots
if self.use_wandb:
if self.use_wandb and self.feature_sparsity is not None:
feature_property_df = self.get_feature_property_df()

fig = px.histogram(
feature_property_df.log_feature_sparsity,
self.feature_sparsity + 1e-10,
nbins=100,
log_x=False,
title="Feature sparsity",
@@ -303,10 +285,10 @@ def run(self):
)
wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))})

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
self.n_features = self.sparse_autoencoder.cfg.d_sae
id_to_start_from = self.get_index_to_resume_from()
id_to_end_at = self.n_features if self.final_index is None else self.final_index
assert id_to_end_at is not None # keep pyright happy

# divide into batches
feature_idx = torch.tensor(range(id_to_start_from, id_to_end_at))
@@ -327,29 +309,37 @@
if self.use_wandb:
wandb.log({"time/time_to_get_tokens": end - start})

vocab_dict = cast(Any, self.model.tokenizer).vocab
vocab_dict = {
v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()
}

with torch.no_grad():
for interesting_features in tqdm(feature_idx):
print(interesting_features)
feature_data = get_feature_data(
encoder=self.sparse_autoencoder,
# encoder_B=sparse_autoencoder,
model=self.model,

feature_vis_params = FeatureVisParams(
hook_point=self.sparse_autoencoder.cfg.hook_point,
hook_point_layer=self.sparse_autoencoder.cfg.hook_point_layer,
hook_point_head_index=None,
tokens=tokens,
feature_idx=interesting_features,
max_batch_size=self.max_batch_size,
left_hand_k=3,
buffer=(self.buffer_tokens, self.buffer_tokens),
n_groups=10,
minibatch_size_features=256,
minibatch_size_tokens=64,
first_group_size=20,
other_groups_size=5,
buffer=(self.buffer_tokens, self.buffer_tokens),
features=interesting_features,
verbose=True,
include_left_tables=False,
)

feature_data = get_feature_data(
encoder=self.sparse_autoencoder, # type: ignore
model=self.model,
tokens=tokens,
fvp=feature_vis_params,
)

for i, test_idx in enumerate(feature_data.keys()):
html_str = feature_data[test_idx].get_all_html()
html_str = feature_data[test_idx].get_html(vocab_dict=vocab_dict)
with open(
f"{self.dashboard_folder}/data_{test_idx:04}.html", "w"
) as f:
@@ -367,7 +357,7 @@ def run(self):
# also upload as html to dashboard
wandb.log(
{
f"features/feature_dashboard": wandb.Html(
"features/feature_dashboard": wandb.Html(
f"{self.dashboard_folder}/data_{test_idx:04}.html"
)
},
18 changes: 0 additions & 18 deletions sae_analysis/visualizer/README.md

This file was deleted.

Empty file.
28 changes: 0 additions & 28 deletions sae_analysis/visualizer/css/general.css

This file was deleted.

