Commit af680e2: format

jbloom-md committed Mar 22, 2024
1 parent e7b238f

Showing 9 changed files with 103 additions and 224 deletions.
17 changes: 13 additions & 4 deletions sae_analysis/dashboard_runner.py
@@ -11,12 +11,12 @@
 import plotly
 import plotly.express as px
 import torch
-import wandb
 from sae_vis.data_fetching_fns import get_feature_data
 from sae_vis.data_storing_fns import FeatureVisParams
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
+import wandb
 from sae_training.utils import LMSparseAutoencoderSessionloader
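Throughout this commit, the recurring import change is the same: `import wandb` moves out of the third-party import block and into the first-party group next to the `sae_training` imports. A minimal sketch of the resulting layout (the section comments are annotations added here, and attributing the grouping to isort's known_first_party setting is an assumption, since the repository's formatter config is not part of this diff):

# third-party section, alphabetized by isort
import plotly
import torch
from tqdm import tqdm

# wandb now sits with first-party code (plausibly via isort known_first_party)
import wandb
from sae_training.utils import LMSparseAutoencoderSessionloader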


@@ -184,7 +184,11 @@ def get_feature_property_df(self):
     @torch.no_grad()
     def get_feature_property_df(self):
         sparse_autoencoder = self.sparse_autoencoder
-        feature_sparsity = self.feature_sparsity
+        feature_sparsity = (
+            self.feature_sparsity
+            if self.feature_sparsity is not None
+            else torch.tensor(0)
+        )
 
         W_dec_normalized = (
             sparse_autoencoder.W_dec.cpu()
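The rewritten assignment guards against `feature_sparsity` never having been computed, substituting a scalar zero tensor instead of propagating `None` into the later `.cpu()` and histogram calls. A standalone sketch of the pattern, with illustrative names:

from typing import Optional

import torch


def safe_sparsity(feature_sparsity: Optional[torch.Tensor]) -> torch.Tensor:
    # fall back to a scalar zero tensor when sparsity was never computed
    return feature_sparsity if feature_sparsity is not None else torch.tensor(0)


print(safe_sparsity(None))  # tensor(0)
print(safe_sparsity(torch.tensor([0.1, 0.9])))  # tensor([0.1000, 0.9000])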
@@ -305,6 +309,11 @@ def run(self):
         if self.use_wandb:
             wandb.log({"time/time_to_get_tokens": end - start})
 
+        vocab_dict = cast(Any, self.model.tokenizer).vocab
+        vocab_dict = {
+            v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()
+        }
+
         with torch.no_grad():
             for interesting_features in tqdm(feature_idx):
                 print(interesting_features)
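The added block precomputes an id-to-string vocabulary for the dashboards: GPT-2-style BPE tokenizers encode a leading space as "Ġ", and raw newlines would break the generated HTML, so both are rewritten. The same transformation on a made-up toy vocabulary:

toy_vocab = {"Ġhello": 31373, "\n": 198, "world": 995}  # token -> id, GPT-2 style
id_to_str = {v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in toy_vocab.items()}
print(id_to_str)  # {31373: ' hello', 198: '\\n', 995: 'world'}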
@@ -330,7 +339,7 @@
             )
 
             for i, test_idx in enumerate(feature_data.keys()):
-                html_str = feature_data[test_idx].get_html()
+                html_str = feature_data[test_idx].get_html(vocab_dict=vocab_dict)
                 with open(
                     f"{self.dashboard_folder}/data_{test_idx:04}.html", "w"
                 ) as f:
@@ -348,7 +357,7 @@
                     # also upload as html to dashboard
                     wandb.log(
                         {
-                            f"features/feature_dashboard": wandb.Html(
+                            "features/feature_dashboard": wandb.Html(
                                 f"{self.dashboard_folder}/data_{test_idx:04}.html"
                             )
                         },
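The dropped `f` prefix is a lint fix: the string has no placeholders, which pyflakes reports as F541, and a plain literal is identical at runtime:

key_old = f"features/feature_dashboard"  # F541: f-string without any placeholders
key_new = "features/feature_dashboard"   # same value, no warning
print(key_old == key_new)  # True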
1 change: 1 addition & 0 deletions sae_training/config.py
@@ -3,6 +3,7 @@
 from typing import Any, Optional, cast
 
 import torch
+
 import wandb


2 changes: 1 addition & 1 deletion sae_training/evals.py
@@ -3,11 +3,11 @@

 import pandas as pd
 import torch
-import wandb
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.utils import get_act_name
 
+import wandb
 from sae_training.activations_store import ActivationsStore
 from sae_training.sparse_autoencoder import SparseAutoencoder

1 change: 0 additions & 1 deletion sae_training/lm_runner.py
@@ -1,7 +1,6 @@
 from typing import Any, cast
 
 import wandb
-
 from sae_training.config import LanguageModelSAERunnerConfig
 
 # from sae_training.activation_store import ActivationStore
2 changes: 1 addition & 1 deletion sae_training/toy_model_runner.py
@@ -3,8 +3,8 @@

 import einops
 import torch
-import wandb
 
+import wandb
 from sae_training.sparse_autoencoder import SparseAutoencoder
 from sae_training.toy_models import Config as ToyConfig
 from sae_training.toy_models import Model as ToyModel
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_language_model.py
@@ -1,11 +1,11 @@
 from typing import Any, cast
 
 import torch
-import wandb
 from torch.optim import Adam
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 
+import wandb
 from sae_training.activations_store import ActivationsStore
 from sae_training.evals import run_evals
 from sae_training.geometric_median import compute_geometric_median
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
 from typing import Any, cast
 
 import torch
-import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+import wandb
 from sae_training.sparse_autoencoder import SparseAutoencoder


254 changes: 52 additions & 202 deletions tutorials/generating_sae_dashboards.ipynb

Large diffs are not rendered by default.

46 changes: 33 additions & 13 deletions tutorials/logits_lens_with_features.ipynb
@@ -200,7 +200,7 @@
 "\n",
 "# if desired, open the features in neuronpedia\n",
 "for feature in tmp_df.feature:\n",
-"    open_neuronpedia(feature, layer = layer)"
+"    open_neuronpedia(feature, layer=layer)"
]
},
{
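This hunk and the two that follow are pure PEP 8 spacing fixes: no spaces around `=` in keyword arguments (pycodestyle E251) and at least two spaces before an inline `#` comment (E261). Both in one runnable sketch, with `show` standing in for the notebook helpers:

def show(feature, layer=0):  # stand-in for open_neuronpedia etc.
    print(feature, layer)

show(42, layer = 3)  # before: E251, spaces around keyword '='
show(42, layer=3)  # after: no spaces around '=', two spaces before this comment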
@@ -227,7 +227,7 @@
 "outputs": [],
 "source": [
 "# get the vocab we need to filter to formulate token sets.\n",
-"vocab = model.tokenizer.get_vocab() # type: ignore\n",
+"vocab = model.tokenizer.get_vocab()  # type: ignore\n",
 "\n",
 "# make a regex dictionary to specify more sets.\n",
 "regex_dict = {\n",
@@ -269,7 +269,7 @@
 "for token_set_name, gene_set in sorted(\n",
 "    all_token_sets.items(), key=lambda x: len(x[1]), reverse=True\n",
 "):\n",
-"    tokens = [model.to_string(id) for id in list(gene_set)][:10] # type: ignore\n",
+"    tokens = [model.to_string(id) for id in list(gene_set)][:10]  # type: ignore\n",
 "    print(f\"{token_set_name}, has {len(gene_set)} genes\")\n",
 "    print(tokens)\n",
 "    print(\"----\")"
@@ -357,7 +357,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"1 digit\", \"2 digits\", \"3 digits\", \"4 digits\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -371,7 +373,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"nltk_pos_PRP\", \"nltk_pos_VBZ\", \"nltk_pos_NNP\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -385,7 +389,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"nltk_pos_VBN\", \"nltk_pos_VBG\", \"nltk_pos_VB\", \"nltk_pos_VBD\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -399,7 +405,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"nltk_pos_WP\", \"nltk_pos_RBR\", \"nltk_pos_WDT\", \"nltk_pos_RB\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -413,7 +421,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"a\", \"e\", \"i\", \"o\", \"u\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -427,7 +437,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"negative_words\", \"positive_words\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -469,7 +481,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"contains_close_bracket\", \"contains_open_bracket\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -495,7 +509,9 @@
 "    \"2000's\",\n",
 "    \"2010's\",\n",
 "]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -509,7 +525,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"positive_words\", \"negative_words\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
@@ -523,7 +541,9 @@
 "outputs": [],
 "source": [
 "token_sets_index = [\"boys_names\", \"girls_names\"]\n",
-"token_set_selected = {k: set(v) for k, v in all_token_sets.items() if k in token_sets_index}\n",
+"token_set_selected = {\n",
+"    k: set(v) for k, v in all_token_sets.items() if k in token_sets_index\n",
+"}\n",
 "df_enrichment_scores = get_enrichment_df(\n",
 "    dec_projection_onto_W_U, features_ordered_by_skew, token_set_selected\n",
 ")\n",
