Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloomAus committed Mar 22, 2024
1 parent e814054 commit 4c03b3d
Show file tree
Hide file tree
Showing 19 changed files with 318 additions and 2,539 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ eindex = {git = "https://github.com/callummcdougall/eindex.git"}
datasets = "^2.17.1"
babe = "^0.0.7"
nltk = "^3.8.1"
sae-vis = {git = "https://github.com/callummcdougall/sae_vis.git"}


[tool.poetry.group.dev.dependencies]
Expand Down
81 changes: 31 additions & 50 deletions sae_analysis/dashboard_runner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
# flake8: noqa: E402
# TODO: are these sys.path.append calls really necessary?

import sys
from typing import Any, cast

sys.path.append("..")
sys.path.append("../..")
import os
from typing import Any, Optional, cast

# set TOKENIZERS_PARALLELISM to false to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand All @@ -18,20 +11,22 @@
import plotly
import plotly.express as px
import torch
import wandb
from sae_vis.data_fetching_fns import get_feature_data
from sae_vis.data_storing_fns import FeatureVisParams
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

import wandb
from sae_analysis.visualizer.data_fns import get_feature_data
from sae_training.utils import LMSparseAutoencoderSessionloader


class DashboardRunner:

def __init__(
self,
sae_path: str | None = None,
sae_path: Optional[str] = None,
dashboard_parent_folder: str = "./feature_dashboards",
wandb_artifact_path: str | None = None,
wandb_artifact_path: Optional[str] = None,
init_session: bool = True,
# token parameters
n_batches_to_sample_from: int = 2**12,
Expand All @@ -43,29 +38,9 @@ def __init__(
# utility parameters
use_wandb: bool = False,
continue_existing_dashboard: bool = True,
final_index: int | None = None,
final_index: Optional[int] = None,
):
"""
# # test it
# runner = DashboardRunner(
# sae_path = None,
# dashboard_parent_folder = "../feature_dashboards",
# wandb_artifact_path = "jbloom/mats_sae_training_gpt2_small_resid_pre_5/sparse_autoencoder_gpt2-small_blocks.2.hook_resid_pre_24576:v19",
# init_session = True,
# n_batches_to_sample_from = 2**12,
# n_prompts_to_select = 4096*6,
# n_features_at_a_time = 128,
# max_batch_size = 256,
# buffer_tokens = 8,
# use_wandb = True,
# continue_existing_dashboard = True,
# )
# runner.run()
"""
""" """

if wandb_artifact_path is not None:
artifact_dir = f"artifacts/{wandb_artifact_path.split('/')[2]}"
Expand Down Expand Up @@ -103,6 +78,7 @@ def __init__(
else:
assert sae_path is not None
self.sae_path = sae_path
self.feature_sparsity = None

if init_session:
self.init_sae_session()
Expand Down Expand Up @@ -152,6 +128,7 @@ def init_sae_session(self):
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
# TODO: handle multiple autoencoders
self.sparse_autoencoder = sae_group.autoencoders[0]

def get_tokens(
Expand All @@ -176,15 +153,16 @@ def get_tokens(

def get_index_to_resume_from(self):
i = 0
assert self.n_features is not None # keep pyright happy
for i in range(self.n_features):
if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"):
break

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
assert self.final_index is not None # keep pyright happy
n_features = self.sparse_autoencoder.cfg.d_sae
n_features_at_a_time = self.n_features_at_a_time
id_of_last_feature_without_dashboard = i
assert self.final_index is not None # keep pyright happy
n_features_remaining = self.final_index - id_of_last_feature_without_dashboard
n_batches_to_do = n_features_remaining // n_features_at_a_time
if self.final_index == n_features:
Expand Down Expand Up @@ -258,11 +236,11 @@ def run(self):
self.init_sae_session()

# generate all the plots
if self.use_wandb:
if self.use_wandb and self.feature_sparsity is not None:
feature_property_df = self.get_feature_property_df()

fig = px.histogram(
feature_property_df.log_feature_sparsity,
self.feature_sparsity + 1e-10,
nbins=100,
log_x=False,
title="Feature sparsity",
Expand Down Expand Up @@ -303,10 +281,10 @@ def run(self):
)
wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))})

assert self.sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
self.n_features = self.sparse_autoencoder.cfg.d_sae
id_to_start_from = self.get_index_to_resume_from()
id_to_end_at = self.n_features if self.final_index is None else self.final_index
assert id_to_end_at is not None # keep pyright happy

# divide into batches
feature_idx = torch.tensor(range(id_to_start_from, id_to_end_at))
Expand All @@ -330,26 +308,29 @@ def run(self):
with torch.no_grad():
for interesting_features in tqdm(feature_idx):
print(interesting_features)
feature_data = get_feature_data(
encoder=self.sparse_autoencoder,
# encoder_B=sparse_autoencoder,
model=self.model,

feature_vis_params = FeatureVisParams(
hook_point=self.sparse_autoencoder.cfg.hook_point,
hook_point_layer=self.sparse_autoencoder.cfg.hook_point_layer,
hook_point_head_index=None,
tokens=tokens,
feature_idx=interesting_features,
max_batch_size=self.max_batch_size,
left_hand_k=3,
buffer=(self.buffer_tokens, self.buffer_tokens),
n_groups=10,
minibatch_size_features=256,
minibatch_size_tokens=64,
first_group_size=20,
other_groups_size=5,
buffer=(self.buffer_tokens, self.buffer_tokens),
features=interesting_features,
verbose=True,
include_left_tables=False,
)

feature_data = get_feature_data(
encoder=self.sparse_autoencoder, # type: ignore
model=self.model,
tokens=tokens,
fvp=feature_vis_params,
)

for i, test_idx in enumerate(feature_data.keys()):
html_str = feature_data[test_idx].get_all_html()
html_str = feature_data[test_idx].get_html()
with open(
f"{self.dashboard_folder}/data_{test_idx:04}.html", "w"
) as f:
Expand Down
18 changes: 0 additions & 18 deletions sae_analysis/visualizer/README.md

This file was deleted.

Empty file.
28 changes: 0 additions & 28 deletions sae_analysis/visualizer/css/general.css

This file was deleted.

131 changes: 0 additions & 131 deletions sae_analysis/visualizer/css/sequences.css

This file was deleted.

17 changes: 0 additions & 17 deletions sae_analysis/visualizer/css/tables.css

This file was deleted.

Loading

0 comments on commit 4c03b3d

Please sign in to comment.