diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..138a4973 --- /dev/null +++ b/.flake8 @@ -0,0 +1,7 @@ +[flake8] +ignore = E203, E266, E501, W503 +max-line-length = 79 +max-complexity = 10 +select = E9, F63, F7, F82 +show-source = true +statistics = true diff --git a/.gitignore b/.gitignore index 29ebafe0..754fd271 100644 --- a/.gitignore +++ b/.gitignore @@ -170,4 +170,4 @@ activations/ *.DS_Store feature_dashboards/ -research/ \ No newline at end of file +research/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..c45c4453 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.2.0 + hooks: + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + args: [--maxkb=250000] +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black +- repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + additional_dependencies: [ + 'flake8-blind-except', + 'flake8-docstrings', + 'flake8-bugbear', + 'flake8-comprehensions', + 'flake8-docstrings', + 'flake8-implicit-str-concat', + 'pydocstyle>=5.0.0', + ] diff --git a/.pylintrc b/.pylintrc index 0127ca79..d4edf945 100644 --- a/.pylintrc +++ b/.pylintrc @@ -16,4 +16,4 @@ default-docstring-type = numpy max-line-length = 100 [MESSAGES CONTROL] -disable = C0330, C0326, C0199, C0411, C103, C0303, C0304 \ No newline at end of file +disable = C0330, C0326, C0199, C0411, C103, C0303, C0304 diff --git a/.vscode/settings.json b/.vscode/settings.json index 6dedeb83..cdf373ce 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -17,4 +17,4 @@ "black" ], "editor.defaultFormatter": "mikoz.black-py", -} \ No newline at end of file +} diff --git a/requirements.txt b/requirements.txt index 7737a85c..6c962a30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ black==23.11.0 pytest==7.4.3 pytest-cov==4.1.0 pre-commit==3.6.0 -git+https://github.com/callummcdougall/eindex.git \ No newline at end of file +git+https://github.com/callummcdougall/eindex.git diff --git a/sae_analysis/dashboard_runner.py b/sae_analysis/dashboard_runner.py index a2d1b9ea..3ff2bc7b 100644 --- a/sae_analysis/dashboard_runner.py +++ b/sae_analysis/dashboard_runner.py @@ -22,31 +22,27 @@ from sae_training.utils import LMSparseAutoencoderSessionloader -class DashboardRunner(): - +class DashboardRunner: def __init__( self, sae_path: str = None, dashboard_parent_folder: str = "./feature_dashboards", wandb_artifact_path: str = None, init_session: bool = True, - # token pars n_batches_to_sample_from: int = 2**12, - n_prompts_to_select: int = 4096*6, - + n_prompts_to_select: int = 4096 * 6, # sampling pars n_features_at_a_time: int = 1024, max_batch_size: int = 256, buffer_tokens: int = 8, - # util pars use_wandb: bool = False, continue_existing_dashboard: bool = True, final_index: int = None, ): - ''' - # # test it + """ + # # test it # runner = DashboardRunner( # sae_path = None, @@ -64,11 +60,10 @@ def __init__( # runner.run() - - ''' - - if wandb_artifact_path is not None: + """ + + if wandb_artifact_path is not None: artifact_dir = f"artifacts/{wandb_artifact_path.split('/')[2]}" if not os.path.exists(artifact_dir): print("Downloading artifact") @@ -77,92 +72,106 @@ def __init__( artifact_dir = artifact.download() path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}" # feature sparsity - feature_sparsity_path = self.get_feature_sparsity_path(wandb_artifact_path) + feature_sparsity_path = self.get_feature_sparsity_path( + wandb_artifact_path + ) artifact = run.use_artifact(feature_sparsity_path) artifact_dir = artifact.download() # add it as a property - self.feature_sparsity = torch.load(f"{artifact_dir}/{os.listdir(artifact_dir)[0]}") + self.feature_sparsity = torch.load( + f"{artifact_dir}/{os.listdir(artifact_dir)[0]}" + ) else: print("Artifact already downloaded") path_to_artifact = f"{artifact_dir}/{os.listdir(artifact_dir)[0]}" - - feature_sparsity_path = self.get_feature_sparsity_path(wandb_artifact_path) + + feature_sparsity_path = self.get_feature_sparsity_path( + wandb_artifact_path + ) artifact_dir = f"artifacts/{feature_sparsity_path.split('/')[2]}" feature_sparsity_file = os.listdir(artifact_dir)[0] - self.feature_sparsity = torch.load(f"{artifact_dir}/{feature_sparsity_file}") - + self.feature_sparsity = torch.load( + f"{artifact_dir}/{feature_sparsity_file}" + ) + self.sae_path = path_to_artifact - else: + else: assert sae_path is not None self.sae_path = sae_path - + if init_session: self.init_sae_session() - + self.n_features_at_a_time = n_features_at_a_time self.max_batch_size = max_batch_size self.buffer_tokens = buffer_tokens self.use_wandb = use_wandb - self.final_index = final_index if final_index is not None else self.sparse_autoencoder.cfg.d_sae + self.final_index = ( + final_index + if final_index is not None + else self.sparse_autoencoder.cfg.d_sae + ) self.n_batches_to_sample_from = n_batches_to_sample_from self.n_prompts_to_select = n_prompts_to_select - - + # Deal with file structure if not os.path.exists(dashboard_parent_folder): os.makedirs(dashboard_parent_folder) - self.dashboard_folder = f"{dashboard_parent_folder}/{self.get_dashboard_folder_name()}" + self.dashboard_folder = ( + f"{dashboard_parent_folder}/{self.get_dashboard_folder_name()}" + ) if not os.path.exists(self.dashboard_folder): os.makedirs(self.dashboard_folder) - + if not continue_existing_dashboard: # check if there are files there and if so abort if len(os.listdir(self.dashboard_folder)) > 0: raise ValueError("Dashboard folder not empty. Aborting.") def get_feature_sparsity_path(self, wandb_artifact_path): - prefix = wandb_artifact_path.split(':')[0] + prefix = wandb_artifact_path.split(":")[0] return f"{prefix}_log_feature_sparsity:v9" - + def get_dashboard_folder_name(self): - model = self.sparse_autoencoder.cfg.model_name hook_point = self.sparse_autoencoder.cfg.hook_point d_sae = self.sparse_autoencoder.cfg.d_sae dashboard_folder_name = f"{model}_{hook_point}_{d_sae}" - + return dashboard_folder_name - + def init_sae_session(self): - - self.model, self.sparse_autoencoder, self.activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained( - self.sae_path - ) - - def get_tokens(self, n_batches_to_sample_from = 2**12, n_prompts_to_select = 4096*6): - ''' + ( + self.model, + self.sparse_autoencoder, + self.activation_store, + ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path) + + def get_tokens( + self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6 + ): + """ Get the tokens needed for dashboard generation. - ''' - + """ + all_tokens_list = [] pbar = tqdm(range(n_batches_to_sample_from)) for _ in pbar: - batch_tokens = self.activation_store.get_batch_tokens() - batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][:batch_tokens.shape[0]] + batch_tokens = batch_tokens[torch.randperm(batch_tokens.shape[0])][ + : batch_tokens.shape[0] + ] all_tokens_list.append(batch_tokens) - + all_tokens = torch.cat(all_tokens_list, dim=0) all_tokens = all_tokens[torch.randperm(all_tokens.shape[0])] return all_tokens[:n_prompts_to_select] def get_index_to_resume_from(self): - for i in range(self.n_features): - if not os.path.exists(f"{self.dashboard_folder}/data_{i:04}.html"): - break + break n_features = self.sparse_autoencoder.cfg.d_sae n_features_at_a_time = self.n_features_at_a_time @@ -170,91 +179,119 @@ def get_index_to_resume_from(self): n_features_remaining = self.final_index - id_of_last_feature_without_dashboard n_batches_to_do = n_features_remaining // n_features_at_a_time if self.final_index == n_features: - id_to_start_from = max(0, n_features - (n_batches_to_do + 1) * n_features_at_a_time) + id_to_start_from = max( + 0, n_features - (n_batches_to_do + 1) * n_features_at_a_time + ) else: - id_to_start_from = 0 # testing purposes only - - + id_to_start_from = 0 # testing purposes only + print(f"File {i} does not exist") print(f"features left to do: {n_features_remaining}") print(f"id_to_start_from: {id_to_start_from}") - print(f"number of batches to do: {(n_features - id_to_start_from) // n_features_at_a_time}") - + print( + f"number of batches to do: {(n_features - id_to_start_from) // n_features_at_a_time}" + ) + return id_to_start_from - + @torch.no_grad() def get_feature_property_df(self): - - sparse_autoencoder= self.sparse_autoencoder + sparse_autoencoder = self.sparse_autoencoder feature_sparsity = self.feature_sparsity - - W_dec_normalized = sparse_autoencoder.W_dec.cpu()# / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True) - W_enc_normalized = sparse_autoencoder.W_enc.cpu() / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True) + + W_dec_normalized = ( + sparse_autoencoder.W_dec.cpu() + ) # / sparse_autoencoder.W_dec.cpu().norm(dim=-1, keepdim=True) + W_enc_normalized = ( + sparse_autoencoder.W_enc.cpu() + / sparse_autoencoder.W_enc.cpu().norm(dim=-1, keepdim=True) + ) d_e_projection = cosine_similarity(W_dec_normalized, W_enc_normalized.T) b_dec_projection = sparse_autoencoder.b_dec.cpu() @ W_dec_normalized.T - temp_df = pd.DataFrame({ - "log_feature_sparsity": feature_sparsity + 1e-10, - "d_e_projection": d_e_projection, - # "d_e_projection_normalized": d_e_projection_normalized, - "b_enc": sparse_autoencoder.b_enc.detach().cpu(), - "feature": [f"feature_{i}" for i in range(sparse_autoencoder.cfg.d_sae)], - "index": torch.arange(sparse_autoencoder.cfg.d_sae), - "dead_neuron": (feature_sparsity < -9).cpu(), - }) - + temp_df = pd.DataFrame( + { + "log_feature_sparsity": feature_sparsity + 1e-10, + "d_e_projection": d_e_projection, + # "d_e_projection_normalized": d_e_projection_normalized, + "b_enc": sparse_autoencoder.b_enc.detach().cpu(), + "feature": [ + f"feature_{i}" for i in range(sparse_autoencoder.cfg.d_sae) + ], + "index": torch.arange(sparse_autoencoder.cfg.d_sae), + "dead_neuron": (feature_sparsity < -9).cpu(), + } + ) + return temp_df - - + def run(self): - ''' + """ Generate the dashboard. - ''' - + """ + if self.use_wandb: # get name from wandb - random_suffix= str(uuid.uuid4())[:8] + random_suffix = str(uuid.uuid4())[:8] name = f"{self.get_dashboard_folder_name()}_{random_suffix}" run = wandb.init( project="feature_dashboards", config=self.sparse_autoencoder.cfg, - name = name, - tags = [ + name=name, + tags=[ f"model_{self.sparse_autoencoder.cfg.model_name}", f"hook_point_{self.sparse_autoencoder.cfg.hook_point}", - ] + ], ) - + if self.model is None: self.init_sae_session() - # generate all the plots if self.use_wandb: feature_property_df = self.get_feature_property_df() - - fig = px.histogram(runner.feature_sparsity+1e-10, nbins=100, log_x=False, title="Feature sparsity") - wandb.log({"plots/feature_density_histogram": wandb.Html(plotly.io.to_html(fig))}) - fig = px.histogram(self.sparse_autoencoder.b_enc.detach().cpu(), title = "b_enc", nbins = 100) + fig = px.histogram( + feature_property_df.log_feature_sparsity, + nbins=100, + log_x=False, + title="Feature sparsity", + ) + wandb.log( + {"plots/feature_density_histogram": wandb.Html(plotly.io.to_html(fig))} + ) + + fig = px.histogram( + self.sparse_autoencoder.b_enc.detach().cpu(), title="b_enc", nbins=100 + ) wandb.log({"plots/b_enc_histogram": wandb.Html(plotly.io.to_html(fig))}) - - fig = px.histogram(feature_property_df.d_e_projection, nbins = 100, title = "D/E projection") - wandb.log({"plots/d_e_projection_histogram": wandb.Html(plotly.io.to_html(fig))}) - - fig = px.histogram(self.sparse_autoencoder.b_dec.detach().cpu(), nbins=100, title = "b_dec projection onto W_dec") - wandb.log({"plots/b_dec_projection_histogram": wandb.Html(plotly.io.to_html(fig))}) - - fig = px.scatter_matrix(feature_property_df, - dimensions = ["log_feature_sparsity", "d_e_projection", "b_enc"], + + fig = px.histogram( + feature_property_df.d_e_projection, nbins=100, title="D/E projection" + ) + wandb.log( + {"plots/d_e_projection_histogram": wandb.Html(plotly.io.to_html(fig))} + ) + + fig = px.histogram( + self.sparse_autoencoder.b_dec.detach().cpu(), + nbins=100, + title="b_dec projection onto W_dec", + ) + wandb.log( + {"plots/b_dec_projection_histogram": wandb.Html(plotly.io.to_html(fig))} + ) + + fig = px.scatter_matrix( + feature_property_df, + dimensions=["log_feature_sparsity", "d_e_projection", "b_enc"], color="dead_neuron", hover_name="feature", opacity=0.2, height=800, - width =1400, + width=1400, ) wandb.log({"plots/scatter_matrix": wandb.Html(plotly.io.to_html(fig))}) - self.n_features = self.sparse_autoencoder.cfg.d_sae id_to_start_from = self.get_index_to_resume_from() @@ -264,19 +301,21 @@ def run(self): feature_idx = torch.tensor(range(id_to_start_from, id_to_end_at)) feature_idx = feature_idx.reshape(-1, self.n_features_at_a_time) feature_idx = [x.tolist() for x in feature_idx] - + print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}") print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") print(f"Writing files to: {self.dashboard_folder}") # get tokens: start = time.time() - tokens = self.get_tokens(self.n_batches_to_sample_from, self.n_prompts_to_select) + tokens = self.get_tokens( + self.n_batches_to_sample_from, self.n_prompts_to_select + ) end = time.time() print(f"Time to get tokens: {end - start}") if self.use_wandb: wandb.log({"time/time_to_get_tokens": end - start}) - + with torch.no_grad(): for interesting_features in tqdm(feature_idx): print(interesting_features) @@ -290,45 +329,51 @@ def run(self): tokens=tokens, feature_idx=interesting_features, max_batch_size=self.max_batch_size, - left_hand_k = 3, - buffer = (self.buffer_tokens, self.buffer_tokens), - n_groups = 10, - first_group_size = 20, - other_groups_size = 5, - verbose = True, + left_hand_k=3, + buffer=(self.buffer_tokens, self.buffer_tokens), + n_groups=10, + first_group_size=20, + other_groups_size=5, + verbose=True, ) - + for i, test_idx in enumerate(feature_data.keys()): html_str = feature_data[test_idx].get_all_html() - with open(f"{self.dashboard_folder}/data_{test_idx:04}.html", "w") as f: + with open( + f"{self.dashboard_folder}/data_{test_idx:04}.html", "w" + ) as f: f.write(html_str) - + if i < 10 and self.use_wandb: # upload the html as an artifact artifact = wandb.Artifact(f"feature_{test_idx}", type="feature") - artifact.add_file(f"{self.dashboard_folder}/data_{test_idx:04}.html") + artifact.add_file( + f"{self.dashboard_folder}/data_{test_idx:04}.html" + ) run.log_artifact(artifact) - + # also upload as html to dashboard wandb.log( - {f"features/feature_dashboard": wandb.Html(f"{self.dashboard_folder}/data_{test_idx:04}.html")}, - step = test_idx - ) - + { + f"features/feature_dashboard": wandb.Html( + f"{self.dashboard_folder}/data_{test_idx:04}.html" + ) + }, + step=test_idx, + ) + # when done zip the folder - shutil.make_archive(self.dashboard_folder, 'zip', self.dashboard_folder) - + shutil.make_archive(self.dashboard_folder, "zip", self.dashboard_folder) + # then upload the zip as an artifact artifact = wandb.Artifact("dashboard", type="zipped_feature_dashboards") artifact.add_file(f"{self.dashboard_folder}.zip") run.log_artifact(artifact) - + # terminate the run run.finish() - + # delete the dashboard folder shutil.rmtree(self.dashboard_folder) - - return - + return diff --git a/sae_analysis/visualizer/README.md b/sae_analysis/visualizer/README.md index 15a9ac58..829f2484 100644 --- a/sae_analysis/visualizer/README.md +++ b/sae_analysis/visualizer/README.md @@ -15,4 +15,4 @@ This particular feature seems to be a fuzzy skip trigram, with the pattern being These visualisations were created using the GELU-1l model from Neel Nanda's HuggingFace library, as well as an autoencoder which he trained on its single layer of neuron activations (see [this Colab](https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn) from Neel). -You can use my [Colab]() to generate more of these visualisations. You can use this [sae visualiser](https://www.perfectlynormal.co.uk/blog-sae) to navigate through the first thousand features of the aforementioned autoencoder. \ No newline at end of file +You can use my [Colab]() to generate more of these visualisations. You can use this [sae visualiser](https://www.perfectlynormal.co.uk/blog-sae) to navigate through the first thousand features of the aforementioned autoencoder. diff --git a/sae_analysis/visualizer/css/general.css b/sae_analysis/visualizer/css/general.css index dcb5f5fc..7b5404c4 100644 --- a/sae_analysis/visualizer/css/general.css +++ b/sae_analysis/visualizer/css/general.css @@ -25,4 +25,4 @@ table { } code { all: unset; -} \ No newline at end of file +} diff --git a/sae_analysis/visualizer/css/sequences.css b/sae_analysis/visualizer/css/sequences.css index b45b7f7f..bcef8d6e 100644 --- a/sae_analysis/visualizer/css/sequences.css +++ b/sae_analysis/visualizer/css/sequences.css @@ -129,6 +129,3 @@ table code { width: 50%; margin-right: -4px; } - - - diff --git a/sae_analysis/visualizer/css/tables.css b/sae_analysis/visualizer/css/tables.css index ffda1234..fdf82c6e 100644 --- a/sae_analysis/visualizer/css/tables.css +++ b/sae_analysis/visualizer/css/tables.css @@ -14,4 +14,4 @@ h4 { } .code-bold code { font-weight: bold; -} \ No newline at end of file +} diff --git a/sae_analysis/visualizer/data_fns.py b/sae_analysis/visualizer/data_fns.py index 446158c7..c6b1c486 100644 --- a/sae_analysis/visualizer/data_fns.py +++ b/sae_analysis/visualizer/data_fns.py @@ -46,8 +46,8 @@ class HistogramData: - ''' - Class for storing all the data necessary to construct a histogram (because e.g. + """ + Class for storing all the data necessary to construct a histogram (because e.g. for a vector with length `d_vocab`, we don't need to store it all!). This is initialised with a tensor of data, and it automatically calculates & stores @@ -55,9 +55,9 @@ class HistogramData: This isn't a dataclass, because the things we hold at the end are not the same as the things we start with! - ''' - def __init__(self, data: Tensor, n_bins: int, tickmode: str): + """ + def __init__(self, data: Tensor, n_bins: int, tickmode: str): if data.numel() == 0: self.bar_heights = [] self.bar_values = [] @@ -74,7 +74,7 @@ def __init__(self, data: Tensor, n_bins: int, tickmode: str): # calculate the heights of each bin bar_heights = torch.histc(data, bins=n_bins) bar_values = bin_edges[:-1] + bin_size / 2 - + # choose tickvalues (super hacky and terrible, should improve this) assert tickmode in ["ints", "5 ticks"] @@ -92,9 +92,13 @@ def __init__(self, data: Tensor, n_bins: int, tickmode: str): num_negative_ticks = 3 num_positive_ticks = int(max_value / tickrange) tick_vals = merge_lists( - reversed([-tickrange * i for i in range(1, 1+num_negative_ticks)]), # negative values (if exist) - [0], # zero (always is a tick) - [tickrange * i for i in range(1, 1+num_positive_ticks)] # positive values + reversed( + [-tickrange * i for i in range(1, 1 + num_negative_ticks)] + ), # negative values (if exist) + [0], # zero (always is a tick) + [ + tickrange * i for i in range(1, 1 + num_positive_ticks) + ], # positive values ) self.bar_heights = bar_heights.tolist() @@ -102,10 +106,9 @@ def __init__(self, data: Tensor, n_bins: int, tickmode: str): self.tick_vals = tick_vals - @dataclass class SequenceData: - ''' + """ Class to store data for a given sequence, which will be turned into a JavaScript visulisation. Before hover: @@ -119,7 +122,8 @@ class SequenceData: top5_logit_changes: list of the corresponding 5 changes in logits for those tokens bottom5_str_tokens: list of the bottom 5 logit-boosted tokens by this feature bottom5_logit_changes: list of the corresponding 5 changes in logits for those tokens - ''' + """ + token_ids: List[str] feat_acts: List[float] contribution_to_loss: List[float] @@ -134,64 +138,70 @@ def __len__(self): def __str__(self): return f"SequenceData({''.join(self.token_ids)})" - + def __post_init__(self): - '''Filters down the data, by deleting the "on hover" information if the activations are zero.''' - self.top5_logit_contributions, self.top5_token_ids = self._filter(self.top5_logit_contributions, self.top5_token_ids) - self.bottom5_logit_contributions, self.bottom5_token_ids = self._filter(self.bottom5_logit_contributions, self.bottom5_token_ids) + """Filters down the data, by deleting the "on hover" information if the activations are zero.""" + self.top5_logit_contributions, self.top5_token_ids = self._filter( + self.top5_logit_contributions, self.top5_token_ids + ) + self.bottom5_logit_contributions, self.bottom5_token_ids = self._filter( + self.bottom5_logit_contributions, self.bottom5_token_ids + ) def _filter(self, float_list: List[List[float]], int_list: List[List[str]]): float_list = [[f for f in floats if f != 0] for floats in float_list] - int_list = [[i for i, f in zip(ints, floats)] for ints, floats in zip(int_list, float_list)] + int_list = [ + [i for i, f in zip(ints, floats)] + for ints, floats in zip(int_list, float_list) + ] return float_list, int_list - class SequenceDataBatch: - ''' + """ Class to store a list of SequenceData objects at once, by passing in tensors or objects with an extra dimension at the start. Note, I'll be creating these objects by passing in objects which are either 2D (k seq_len) or 3D (k seq_len top5), but which are all lists (of strings/ints/floats). - ''' + """ + def __init__(self, **kwargs): self.seqs = [ SequenceData( - token_ids = kwargs["token_ids"][k], - feat_acts = kwargs["feat_acts"][k], - contribution_to_loss = kwargs["contribution_to_loss"][k], - repeat = kwargs["repeat"], - top5_token_ids = kwargs["top5_token_ids"][k], - top5_logit_contributions = kwargs["top5_logit_contributions"][k], - bottom5_token_ids = kwargs["bottom5_token_ids"][k], - bottom5_logit_contributions = kwargs["bottom5_logit_contributions"][k], + token_ids=kwargs["token_ids"][k], + feat_acts=kwargs["feat_acts"][k], + contribution_to_loss=kwargs["contribution_to_loss"][k], + repeat=kwargs["repeat"], + top5_token_ids=kwargs["top5_token_ids"][k], + top5_logit_contributions=kwargs["top5_logit_contributions"][k], + bottom5_token_ids=kwargs["bottom5_token_ids"][k], + bottom5_logit_contributions=kwargs["bottom5_logit_contributions"][k], ) for k in range(len(kwargs["token_ids"])) ] def __getitem__(self, idx: int) -> SequenceData: return self.seqs[idx] - + def __len__(self) -> int: return len(self.seqs) - + def __str__(self) -> str: return "\n".join([str(seq) for seq in self.seqs]) - @dataclass class FeatureData: - ''' + """ Class to store all data for a feature that will be used in the visualization. Also has a bunch of methods to create visualisations. So this is the main important class. The biggest arg is `sequence_data`, it's explained below. The other args are individual, and are used to construct the left-hand visualisations. - + Args for the right-hand sequences: sequence_data: Dict[str, SequenceDataBatch] @@ -199,9 +209,9 @@ class FeatureData: Each key is a group name (there are 12 in total: top, bottom, 10 quantiles), and each value is a SequenceDataBatch object (i.e. it contains a batch of SequenceData objects, one for each sequence in the group). See these classes for more on how these are used. - + Args for the middle column: - + top10_logits: Tuple[TopK, TopK] Contains the most neg / pos 10 logits, used for the logits table @@ -215,17 +225,17 @@ class FeatureData: Also used for frequencies histogram, this is the fraction of activations which are non-zero Args for the left-hand column - + neuron_alignment: Tuple[TopK, Tensor] first element is the topk aligned neurons (i.e. argmax on decoder weights) second element is the fraction of L1 norm this neuron makes up, in this decoder weight vector. - + neurons_correlated: Tuple[TopK, TopK] the topk neurons most correlated with each other, i.e. this feature has (N,) activations and the neurons have (d_mlp, N) activations on these tokens, where N = batch_size * seq_len, and - we find the neuron (column of second tensor) with highest correlation. Contains Pearson & + we find the neuron (column of second tensor) with highest correlation. Contains Pearson & Cosine sim (difference is that Pearson centers weights first). - + b_features_correlated: Tuple[TopK, TopK] same datatype as neurons_correlated, but now we're looking at this feature's (N,) activations and comparing them to the (h, N) activations of the encoder-B features (where h is the hidden @@ -235,7 +245,7 @@ class FeatureData: model: HookedTransformer The thing you're actually doing forward passes through, and finding features of - + encoder: AutoEncoder The encoder of the model, which you're using to find features @@ -245,9 +255,10 @@ class FeatureData: n_groups, first_group_size, other_groups_size All params to determine size of the sequences in right hand of visualisation. - ''' + """ + sequence_data: Dict[str, SequenceDataBatch] - + top10_logits: Tuple[TopK, TopK] logits_histogram_data: HistogramData frequencies_histogram_data: HistogramData @@ -263,25 +274,26 @@ class FeatureData: first_group_size: int = 20 other_groups_size: int = 5 - def return_save_dict(self) -> dict: - '''Returns a dict we use for saving (pickling).''' - return { - k: v for k, v in self.__dict__.items() - if k not in ["vocab_dict"] - } - + """Returns a dict we use for saving (pickling).""" + return {k: v for k, v in self.__dict__.items() if k not in ["vocab_dict"]} @classmethod def load_from_save_dict(self, save_dict, vocab_dict): - '''Loads this object from a dict (e.g. from a pickle file).''' + """Loads this object from a dict (e.g. from a pickle file).""" return FeatureData(**save_dict, vocab_dict=vocab_dict) - @classmethod - def save_batch(cls, batch: Dict[int, "FeatureData"], filename: str, save_type: Literal["pkl", "gzip"]) -> None: - '''Saves a batch of FeatureData objects to a pickle file.''' - assert "." not in filename, "You should pass in the filename without the extension." + def save_batch( + cls, + batch: Dict[int, "FeatureData"], + filename: str, + save_type: Literal["pkl", "gzip"], + ) -> None: + """Saves a batch of FeatureData objects to a pickle file.""" + assert ( + "." not in filename + ), "You should pass in the filename without the extension." filename = filename + ".pkl" if (save_type == "pkl") else filename + ".pkl.gz" save_obj = {k: v.return_save_dict() for k, v in batch.items()} if save_type == "pkl": @@ -291,13 +303,22 @@ def save_batch(cls, batch: Dict[int, "FeatureData"], filename: str, save_type: L with gzip.open(filename, "wb") as f: pickle.dump(save_obj, f) return filename - @classmethod - def load_batch(cls, filename: str, save_type: Literal["pkl", "gzip"], vocab_dict: Dict[int, str], feature_idx: Optional[int] = None) -> Union["FeatureData", Dict[int, "FeatureData"]]: - '''Loads a batch of FeatureData objects from a pickle file.''' - assert "." not in filename, "You should pass in the filename without the extension." - filename = filename + ".pkl" if save_type.startswith("pkl") else filename + ".pkl.gz" + def load_batch( + cls, + filename: str, + save_type: Literal["pkl", "gzip"], + vocab_dict: Dict[int, str], + feature_idx: Optional[int] = None, + ) -> Union["FeatureData", Dict[int, "FeatureData"]]: + """Loads a batch of FeatureData objects from a pickle file.""" + assert ( + "." not in filename + ), "You should pass in the filename without the extension." + filename = ( + filename + ".pkl" if save_type.startswith("pkl") else filename + ".pkl.gz" + ) if save_type.startswith("pkl"): with open(filename, "rb") as f: save_obj = pickle.load(f) @@ -306,14 +327,18 @@ def load_batch(cls, filename: str, save_type: Literal["pkl", "gzip"], vocab_dict save_obj = pickle.load(f) if feature_idx is None: - return {k: FeatureData.load_from_save_dict(v, vocab_dict) for k, v in save_obj.items()} - else: + return { + k: FeatureData.load_from_save_dict(v, vocab_dict) + for k, v in save_obj.items() + } + else: return FeatureData.load_from_save_dict(save_obj[feature_idx], vocab_dict) - def save(self, filename: str, save_type: Literal["pkl", "gzip"]) -> None: - '''Saves this object to a pickle file (we don't need to save the model and encoder too, just the data).''' - assert "." not in filename, "You should pass in the filename without the extension." + """Saves this object to a pickle file (we don't need to save the model and encoder too, just the data).""" + assert ( + "." not in filename + ), "You should pass in the filename without the extension." filename = filename + ".pkl" if (save_type == "pkl") else filename + ".pkl.gz" save_obj = self.return_save_dict() if save_type.startswith("pkl"): @@ -324,37 +349,35 @@ def save(self, filename: str, save_type: Literal["pkl", "gzip"]) -> None: pickle.dump(save_obj, f) return filename - def __str__(self) -> str: num_sequences = sum([len(batch) for batch in self.sequence_data.values()]) return f"FeatureData(num_sequences={num_sequences})" - def get_sequences_html(self) -> str: - sequences_html_dict = {} for group_name, sequences in self.sequence_data.items(): - - full_html = f'

{group_name}

' # style="padding-left:25px;" - + full_html = f"

{group_name}

" # style="padding-left:25px;" + for seq in sequences: html_output = generate_seq_html( self.vocab_dict, - token_ids = seq.token_ids, - feat_acts = seq.feat_acts, - contribution_to_loss = seq.contribution_to_loss, - bold_idx = self.buffer[0], # e.g. the 6th item, with index 5, if buffer=(5, 5) - is_repeat = seq.repeat, - pos_ids = seq.top5_token_ids, - neg_ids = seq.bottom5_token_ids, - pos_val = seq.top5_logit_contributions, - neg_val = seq.bottom5_logit_contributions, + token_ids=seq.token_ids, + feat_acts=seq.feat_acts, + contribution_to_loss=seq.contribution_to_loss, + bold_idx=self.buffer[ + 0 + ], # e.g. the 6th item, with index 5, if buffer=(5, 5) + is_repeat=seq.repeat, + pos_ids=seq.top5_token_ids, + neg_ids=seq.bottom5_token_ids, + pos_val=seq.top5_logit_contributions, + neg_val=seq.bottom5_logit_contributions, ) full_html += html_output - + sequences_html_dict[group_name] = full_html - + # Now, wrap all the values of this dictionary into grid-items: (top, groups of 3 for middle, bottom) html_top, html_bottom, *html_sampled = sequences_html_dict.values() sequences_html = "" @@ -367,53 +390,52 @@ def get_sequences_html(self) -> str: return sequences_html + HTML_HOVERTEXT_SCRIPT - def get_tables_html(self) -> Tuple[str, str]: - bottom10_logits, top10_logits = self.top10_logits # Get the negative and positive background values (darkest when equals max abs). Easier when in tensor form - max_value = max(np.absolute(bottom10_logits.values).max(), np.absolute(top10_logits.values).max()) + max_value = max( + np.absolute(bottom10_logits.values).max(), + np.absolute(top10_logits.values).max(), + ) neg_bg_values = np.absolute(bottom10_logits.values) / max_value pos_bg_values = np.absolute(top10_logits.values) / max_value - + # Generate the html left_tables_html, logit_tables_html = generate_tables_html( - neuron_alignment_indices = self.neuron_alignment[0].indices.tolist(), - neuron_alignment_values = self.neuron_alignment[0].values.tolist(), - neuron_alignment_l1 = self.neuron_alignment[1].tolist(), - correlated_neurons_indices = self.neurons_correlated[0].indices.tolist(), - correlated_neurons_pearson = self.neurons_correlated[0].values.tolist(), - correlated_neurons_l1 = self.neurons_correlated[1].values.tolist(), - correlated_features_indices = None, #self.b_features_correlated[0].indices.tolist(), - correlated_features_pearson = None,#self.b_features_correlated[0].values.tolist(), - correlated_features_l1 = None,#self.b_features_correlated[1].values.tolist(), - + neuron_alignment_indices=self.neuron_alignment[0].indices.tolist(), + neuron_alignment_values=self.neuron_alignment[0].values.tolist(), + neuron_alignment_l1=self.neuron_alignment[1].tolist(), + correlated_neurons_indices=self.neurons_correlated[0].indices.tolist(), + correlated_neurons_pearson=self.neurons_correlated[0].values.tolist(), + correlated_neurons_l1=self.neurons_correlated[1].values.tolist(), + correlated_features_indices=None, # self.b_features_correlated[0].indices.tolist(), + correlated_features_pearson=None, # self.b_features_correlated[0].values.tolist(), + correlated_features_l1=None, # self.b_features_correlated[1].values.tolist(), neg_str=to_str_tokens(self.vocab_dict, bottom10_logits.indices), neg_values=bottom10_logits.values.tolist(), neg_bg_values=neg_bg_values, pos_str=to_str_tokens(self.vocab_dict, top10_logits.indices), pos_values=top10_logits.values.tolist(), - pos_bg_values=pos_bg_values + pos_bg_values=pos_bg_values, ) # Return both items (we'll be wrapping them in 'grid-item' later) return left_tables_html, logit_tables_html - def get_histograms(self) -> Tuple[str, str]: - ''' + """ From the histogram data, returns the actual histogram HTML strings. - ''' - frequencies_histogram, logits_histogram = generate_histograms(self.frequencies_histogram_data, self.logits_histogram_data) + """ + frequencies_histogram, logits_histogram = generate_histograms( + self.frequencies_histogram_data, self.logits_histogram_data + ) return ( f"

ACTIVATIONS
DENSITY = {self.frac_nonzero:.3%}

{frequencies_histogram}
", - f"
{logits_histogram}
" + f"
{logits_histogram}
", ) - def get_all_html(self, debug: bool = False, split_scripts: bool = False) -> str: - # Get the individual HTML left_tables_html, logit_tables_html = self.get_tables_html() sequences_html = self.get_sequences_html() @@ -439,7 +461,7 @@ def get_all_html(self, debug: bool = False, split_scripts: bool = False) -> str: """ # idk why this bug is here, for representing newlines the wrong way html_string = html_string.replace("Ċ", "\n") - + if debug: display(HTML(html_string)) @@ -448,14 +470,10 @@ def get_all_html(self, debug: bool = False, split_scripts: bool = False) -> str: return scripts, html_string else: return html_string - - - - class BatchedCorrCoef: - ''' + """ This class allows me to calculate corrcoef (both Pearson and cosine sim) between two batches of vectors without needing to store them all in memory. @@ -472,7 +490,8 @@ class BatchedCorrCoef: denom = (n * x2_sum - x_sum ** 2) ** 0.5 * (n * y2_sum - y_sum ** 2) ** 0.5 ...and all these quantities (x_sum, xy_sum, etc) can be tracked on a rolling basis. - ''' + """ + def __init__(self): self.n = 0 self.x_sum = 0 @@ -481,34 +500,44 @@ def __init__(self): self.x2_sum = 0 self.y2_sum = 0 - def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]): + def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]): # noqa assert x.ndim == 2 and y.ndim == 2, "Both x and y should be 2D" - assert x.shape[-1] == y.shape[-1], "x and y should have the same size in the last dimension" - + assert ( + x.shape[-1] == y.shape[-1] + ), "x and y should have the same size in the last dimension" + self.n += x.shape[-1] self.x_sum += einops.reduce(x, "X N -> X", "sum") self.y_sum += einops.reduce(y, "Y N -> Y", "sum") self.xy_sum += einops.einsum(x, y, "X N, Y N -> X Y") - self.x2_sum += einops.reduce(x ** 2, "X N -> X", "sum") - self.y2_sum += einops.reduce(y ** 2, "Y N -> Y", "sum") + self.x2_sum += einops.reduce(x**2, "X N -> X", "sum") + self.y2_sum += einops.reduce(y**2, "Y N -> Y", "sum") - def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]: + def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]: # noqa cossim_numer = self.xy_sum cossim_denom = torch.sqrt(torch.outer(self.x2_sum, self.y2_sum)) + 1e-6 cossim = cossim_numer / cossim_denom pearson_numer = self.n * self.xy_sum - torch.outer(self.x_sum, self.y_sum) - pearson_denom = torch.sqrt(torch.outer(self.n * self.x2_sum - self.x_sum ** 2, self.n * self.y2_sum - self.y_sum ** 2)) + 1e-6 + pearson_denom = ( + torch.sqrt( + torch.outer( + self.n * self.x2_sum - self.x_sum**2, + self.n * self.y2_sum - self.y_sum**2, + ) + ) + + 1e-6 + ) pearson = pearson_numer / pearson_denom return pearson, cossim def topk(self, k: int, largest: bool = True) -> Tuple[TopK, TopK]: - '''Returns the topk corrcoefs, using Pearson (and taking this over the y-tensor)''' + """Returns the topk corrcoefs, using Pearson (and taking this over the y-tensor)""" pearson, cossim = self.corrcoef() X, Y = cossim.shape # Get pearson topk by actually taking topk - pearson_topk = TopK(pearson.topk(dim=-1, k=k, largest=largest)) # shape (X, k) + pearson_topk = TopK(pearson.topk(dim=-1, k=k, largest=largest)) # shape (X, k) # Get cossim topk by indexing into cossim with the indices of the pearson topk: cossim[X, pearson_indices[X, k]] cossim_values = eindex(cossim, pearson_topk.indices, "X [X k]") cossim_topk = TopK((cossim_values, pearson_topk.indices)) @@ -523,19 +552,17 @@ def get_feature_data( hook_point: str, hook_point_layer: int, hook_point_head_index: Optional[int], - tokens: Int[Tensor, "batch seq"], + tokens: Int[Tensor, "batch seq"], # noqa feature_idx: Union[int, List[int]], max_batch_size: Optional[int] = None, - left_hand_k: int = 3, buffer: Tuple[int, int] = (5, 5), n_groups: int = 10, first_group_size: int = 20, other_groups_size: int = 5, verbose: bool = False, - ) -> Dict[int, FeatureData]: - ''' + """ Gets data that will be used to create the sequences in the HTML visualisation. Args: @@ -551,14 +578,15 @@ def get_feature_data( The number of tokens on either side of the feature, for the right-hand visualisation. Returns object of class FeatureData (see that class's docstring for more info). - ''' + """ t0 = time.time() model.reset_hooks(including_permanent=True) device = model.cfg.device # Make feature_idx a list, for convenience - if isinstance(feature_idx, int): feature_idx = [feature_idx] + if isinstance(feature_idx, int): + feature_idx = [feature_idx] n_feats = len(feature_idx) # Chunk the tokens, for less memory usage @@ -574,25 +602,35 @@ def get_feature_data( # corrcoef_encoder_B = BatchedCorrCoef() # Get encoder & decoder directions - feature_act_dir = encoder.W_enc[:, feature_idx] # (d_in, feats) - feature_bias = encoder.b_enc[feature_idx] # (feats,) - feature_out_dir = encoder.W_dec[feature_idx] # (feats, d_in) - + feature_act_dir = encoder.W_enc[:, feature_idx] # (d_in, feats) + feature_bias = encoder.b_enc[feature_idx] # (feats,) + feature_out_dir = encoder.W_dec[feature_idx] # (feats, d_in) + if "resid_pre" in hook_point: - feature_mlp_out_dir = feature_out_dir # (feats, d_model) + feature_mlp_out_dir = feature_out_dir # (feats, d_model) elif "resid_post" in hook_point: - feature_mlp_out_dir = feature_out_dir @ model.W_out[hook_point_layer] # (feats, d_model) + feature_mlp_out_dir = ( + feature_out_dir @ model.W_out[hook_point_layer] + ) # (feats, d_model) elif "hook_q" in hook_point: # unembed proj onto residual stream - feature_mlp_out_dir = feature_out_dir @ model.W_Q[hook_point_layer, hook_point_head_index].T # (feats, d_model)ß - assert feature_act_dir.T.shape == feature_out_dir.shape == (len(feature_idx), encoder.cfg.d_in) + feature_mlp_out_dir = ( + feature_out_dir @ model.W_Q[hook_point_layer, hook_point_head_index].T + ) # (feats, d_model)ß + assert ( + feature_act_dir.T.shape + == feature_out_dir.shape + == (len(feature_idx), encoder.cfg.d_in) + ) t1 = time.time() # ! Define hook function to perform feature ablation - def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint): - ''' + def hook_fn_act_post( + act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint # noqa + ): # noqa + """ Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where: - f_i are the feature activations - d_i are the feature output directions @@ -601,10 +639,12 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint if we did this, then we'd have to run a different fwd pass for every feature, which is super wasteful! But later, we'll calculate the effect of feature ablation, i.e. x^j <- x^j - f_i(x^j)d_i for i = feature_idx, only on the tokens we care about (the ones which will appear in the visualisation). - ''' + """ # Calculate & store the feature activations (we need to store them so we can get the right-hand visualisations later) x_cent = act_post - encoder.b_dec - feat_acts_pre = einops.einsum(x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats") + feat_acts_pre = einops.einsum( + x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats" + ) feat_acts = F.relu(feat_acts_pre + feature_bias) all_feat_acts.append(feat_acts) @@ -613,7 +653,7 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"), einops.rearrange(act_post, "batch seq d_mlp -> d_mlp (batch seq)"), ) - + # Calculate encoder-B feature activations (we don't need to store them, cause it's just for the left-hand visualisations) # x_cent_B = act_post - encoder_B.b_dec # feat_acts_pre_B = einops.einsum(x_cent_B, encoder_B.W_enc, "batch seq d_mlp, d_mlp d_hidden -> batch seq d_hidden") @@ -624,11 +664,13 @@ def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint # einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"), # einops.rearrange(feat_acts_B, "batch seq d_hidden -> d_hidden (batch seq)"), # ) - - def hook_fn_query(hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint): - ''' - - Replace act_post with projection of query onto the resid by W_k^T. + + def hook_fn_query( + hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint # noqa + ): + """ + + Replace act_post with projection of query onto the resid by W_k^T. Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where: - f_i are the feature activations - d_i are the feature output directions @@ -637,75 +679,105 @@ def hook_fn_query(hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPo if we did this, then we'd have to run a different fwd pass for every feature, which is super wasteful! But later, we'll calculate the effect of feature ablation, i.e. x^j <- x^j - f_i(x^j)d_i for i = feature_idx, only on the tokens we care about (the ones which will appear in the visualisation). - ''' + """ # Calculate & store the feature activations (we need to store them so we can get the right-hand visualisations later) hook_q = hook_q[:, :, hook_point_head_index] x_cent = hook_q - encoder.b_dec - feat_acts_pre = einops.einsum(x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats") + feat_acts_pre = einops.einsum( + x_cent, feature_act_dir, "batch seq d_mlp, d_mlp feats -> batch seq feats" + ) feat_acts = F.relu(feat_acts_pre + feature_bias) all_feat_acts.append(feat_acts) - - # project this back up to resid stream size. + + # project this back up to resid stream size. act_resid_proj = hook_q @ model.W_Q[hook_point_layer, hook_point_head_index].T # Update the CorrCoef object between feature activation & neurons corrcoef_neurons.update( einops.rearrange(feat_acts, "batch seq feats -> feats (batch seq)"), - einops.rearrange(act_resid_proj, "batch seq d_model -> d_model (batch seq)"), + einops.rearrange( + act_resid_proj, "batch seq d_model -> d_model (batch seq)" + ), ) - - def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint): - ''' + def hook_fn_resid_post( + resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint # noqa + ): + """ This hook function stores the residual activations, which we'll need later on to calculate the effect of feature ablation. - ''' + """ all_resid_post.append(resid_post) - # Run the model without hook (to store all the information we need, not to actually return anything) - + # ! Run the forward passes (triggering the hooks), concat all results iterator = tqdm(all_tokens, desc="Storing model activations") if "resid_pre" in hook_point: for _tokens in iterator: - model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[ - (hook_point, hook_fn_act_post), - (utils.get_act_name("resid_pre", hook_point_layer), hook_fn_resid_post) - ]) - # If we are using MLP activations, then we'd want this one. + model.run_with_hooks( + _tokens, + return_type=None, + fwd_hooks=[ + (hook_point, hook_fn_act_post), + ( + utils.get_act_name("resid_pre", hook_point_layer), + hook_fn_resid_post, + ), + ], + ) + # If we are using MLP activations, then we'd want this one. elif "resid_post" in hook_point: for _tokens in iterator: - - model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[ - (utils.get_act_name("post", hook_point_layer), hook_fn_act_post), - (utils.get_act_name("resid_post", hook_point_layer), hook_fn_resid_post) - ]) + model.run_with_hooks( + _tokens, + return_type=None, + fwd_hooks=[ + (utils.get_act_name("post", hook_point_layer), hook_fn_act_post), + ( + utils.get_act_name("resid_post", hook_point_layer), + hook_fn_resid_post, + ), + ], + ) elif "hook_q" in hook_point: iterator = tqdm(all_tokens, desc="Storing model activations") for _tokens in iterator: - model.run_with_hooks(_tokens, return_type=None, fwd_hooks=[ - (hook_point, hook_fn_query), - (utils.get_act_name("resid_post", hook_point_layer), hook_fn_resid_post) - ]) - - + model.run_with_hooks( + _tokens, + return_type=None, + fwd_hooks=[ + (hook_point, hook_fn_query), + ( + utils.get_act_name("resid_post", hook_point_layer), + hook_fn_resid_post, + ), + ], + ) + t2 = time.time() # Stack the results, and check shapes (remember that we don't get loss for the last token) - feat_acts = torch.concatenate(all_feat_acts) # [batch seq feats] - resid_post = torch.concatenate(all_resid_post) # [batch seq d_model] + feat_acts = torch.concatenate(all_feat_acts) # [batch seq feats] + resid_post = torch.concatenate(all_resid_post) # [batch seq d_model] assert feat_acts[:, :-1].shape == tokens[:, :-1].shape + (len(feature_idx),) t3 = time.time() - - # ! Calculate all data for the left-hand column visualisations, i.e. the 3 size-3 tables # First, get the logits of this feature - logits = einops.einsum(feature_mlp_out_dir, model.W_U, "feats d_model, d_model d_vocab -> feats d_vocab") + logits = einops.einsum( + feature_mlp_out_dir, + model.W_U, + "feats d_model, d_model d_vocab -> feats d_vocab", + ) # Second, get the neurons most aligned with this feature (based on output weights) - top3_neurons_aligned = TopK(feature_out_dir.topk(dim=-1, k=left_hand_k, largest=True)) - pct_of_l1 = np.absolute(top3_neurons_aligned.values) / feature_out_dir.abs().sum(dim=-1, keepdim=True).cpu().numpy() + top3_neurons_aligned = TopK( + feature_out_dir.topk(dim=-1, k=left_hand_k, largest=True) + ) + pct_of_l1 = ( + np.absolute(top3_neurons_aligned.values) + / feature_out_dir.abs().sum(dim=-1, keepdim=True).cpu().numpy() + ) # Third, get the neurons most correlated with this feature (based on input weights) top_correlations_neurons = corrcoef_neurons.topk(k=left_hand_k, largest=True) # Lastly, get most correlated weights in B features @@ -713,8 +785,6 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo t4 = time.time() - - # ! Calculate all data for the right-hand visualisations, i.e. the sequences # TODO - parallelize this (it could probably be sped up by batching indices & doing all sequences at once, although those would be large tensors) # We do this in 2 steps: @@ -725,24 +795,35 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo sequence_data_list = [] - iterator = range(n_feats) if not(verbose) else tqdm(range(n_feats), desc="Getting sequence data", leave=False) + iterator = ( + range(n_feats) + if not (verbose) + else tqdm(range(n_feats), desc="Getting sequence data", leave=False) + ) for feat in iterator: + _feat_acts = feat_acts[..., feat] # [batch seq] - _feat_acts = feat_acts[..., feat] # [batch seq] - # (1) indices_dict = { - f"TOP ACTIVATIONS
MAX = {_feat_acts.max():.3f}": k_largest_indices(_feat_acts, k=first_group_size, largest=True), - f"BOTTOM ACTIVATIONS
MIN = {_feat_acts.min():.3f}": k_largest_indices(_feat_acts, k=first_group_size, largest=False), + f"TOP ACTIVATIONS
MAX = {_feat_acts.max():.3f}": k_largest_indices( + _feat_acts, k=first_group_size, largest=True + ), + f"BOTTOM ACTIVATIONS
MIN = {_feat_acts.min():.3f}": k_largest_indices( + _feat_acts, k=first_group_size, largest=False + ), } - quantiles = torch.linspace(0, _feat_acts.max(), n_groups+1) - for i in range(n_groups-1, -1, -1): - lower, upper = quantiles[i:i+2] + quantiles = torch.linspace(0, _feat_acts.max(), n_groups + 1) + for i in range(n_groups - 1, -1, -1): + lower, upper = quantiles[i : i + 2] pct = ((_feat_acts >= lower) & (_feat_acts <= upper)).float().mean() - indices = random_range_indices(_feat_acts, (lower, upper), k=other_groups_size) - indices_dict[f"INTERVAL {lower:.3f} - {upper:.3f}
CONTAINS {pct:.3%}"] = indices + indices = random_range_indices( + _feat_acts, (lower, upper), k=other_groups_size + ) + indices_dict[ + f"INTERVAL {lower:.3f} - {upper:.3f}
CONTAINS {pct:.3%}" + ] = indices # Concat all the indices together (in the next steps we do all groups at once) indices_full = torch.concat(list(indices_dict.values())) @@ -753,35 +834,59 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo # i.e. indices[..., 0] = shape (g, buf) contains the batch indices of the sequences, and indices[..., 1] = contains seq indices # (B) index into all our tensors to get the relevant data (this includes calculating the effect of ablation) # (C) construct the SequenceData objects, in the form of a SequenceDataBatch object - + # (A) # For each token index [batch, seq], we actually want [[batch, seq-buffer[0]], ..., [batch, seq], ..., [batch, seq+buffer[1]]] # We get one extra dimension at the start, because we need to see the effect on loss of the first token - buffer_tensor = torch.arange(-buffer[0] - 1, buffer[1] + 1, device=indices_full.device) - indices_full = einops.repeat(indices_full, "g two -> g buf two", buf=buffer[0] + buffer[1] + 2) - indices_full = torch.stack([indices_full[..., 0], indices_full[..., 1] + buffer_tensor], dim=-1).cpu() + buffer_tensor = torch.arange( + -buffer[0] - 1, buffer[1] + 1, device=indices_full.device + ) + indices_full = einops.repeat( + indices_full, "g two -> g buf two", buf=buffer[0] + buffer[1] + 2 + ) + indices_full = torch.stack( + [indices_full[..., 0], indices_full[..., 1] + buffer_tensor], dim=-1 + ).cpu() # (B) # Template for indexing is new_tensor[k, seq] = tensor[indices_full[k, seq, 1], indices_full[k, seq, 2]], sometimes there's an extra dim at the end tokens_group = eindex(tokens, indices_full[:, 1:], "[g buf 0] [g buf 1]") feat_acts_group = eindex(_feat_acts, indices_full, "[g buf 0] [g buf 1]") - resid_post_group = eindex(resid_post, indices_full, "[g buf 0] [g buf 1] d_model") + resid_post_group = eindex( + resid_post, indices_full, "[g buf 0] [g buf 1] d_model" + ) # From these feature activations, get the actual contribution to the final value of the residual stream - resid_post_feature_effect = einops.einsum(feat_acts_group, feature_mlp_out_dir[feat], "g buf, d_model -> g buf d_model") + resid_post_feature_effect = einops.einsum( + feat_acts_group, + feature_mlp_out_dir[feat], + "g buf, d_model -> g buf d_model", + ) # Get the resulting new logits (by subtracting this effect from resid_post, then applying layernorm & unembedding) new_resid_post = resid_post_group - resid_post_feature_effect - new_logits = (new_resid_post / new_resid_post.std(dim=-1, keepdim=True)) @ model.W_U - orig_logits = (resid_post_group / resid_post_group.std(dim=-1, keepdim=True)) @ model.W_U + new_logits = ( + new_resid_post / new_resid_post.std(dim=-1, keepdim=True) + ) @ model.W_U + orig_logits = ( + resid_post_group / resid_post_group.std(dim=-1, keepdim=True) + ) @ model.W_U # Get the top5 & bottom5 changes in logits # note - changes in logits are for hovering over predict-ING token, so it should align w/ tokens_group, hence we slice [:, 1:] - contribution_to_logprobs = orig_logits.log_softmax(dim=-1) - new_logits.log_softmax(dim=-1) - top5_contribution_to_logits = TopK(contribution_to_logprobs[:, :-1].topk(k=5, largest=True)) - bottom5_contribution_to_logits = TopK(contribution_to_logprobs[:, :-1].topk(k=5, largest=False)) + contribution_to_logprobs = orig_logits.log_softmax( + dim=-1 + ) - new_logits.log_softmax(dim=-1) + top5_contribution_to_logits = TopK( + contribution_to_logprobs[:, :-1].topk(k=5, largest=True) + ) + bottom5_contribution_to_logits = TopK( + contribution_to_logprobs[:, :-1].topk(k=5, largest=False) + ) # Get the change in loss (which is negative of change of logprobs for correct token) # note - changes in loss are for underlining predict-ED token, hence we slice [:, :-1] - contribution_to_loss = eindex(-contribution_to_logprobs[:, :-1], tokens_group, "g buf [g buf]") + contribution_to_loss = eindex( + -contribution_to_logprobs[:, :-1], tokens_group, "g buf [g buf]" + ) # (C) # Now that we've indexed everything, construct the batch of SequenceData objects @@ -790,14 +895,22 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo for group_name, indices in indices_dict.items(): lower, upper = g_total, g_total + len(indices) sequence_data[group_name] = SequenceDataBatch( - token_ids=tokens_group[lower: upper].tolist(), - feat_acts=feat_acts_group[lower: upper, 1:].tolist(), - contribution_to_loss=contribution_to_loss[lower: upper].tolist(), + token_ids=tokens_group[lower:upper].tolist(), + feat_acts=feat_acts_group[lower:upper, 1:].tolist(), + contribution_to_loss=contribution_to_loss[lower:upper].tolist(), repeat=False, - top5_token_ids=top5_contribution_to_logits.indices[lower: upper].tolist(), - top5_logit_contributions=top5_contribution_to_logits.values[lower: upper].tolist(), - bottom5_token_ids=bottom5_contribution_to_logits.indices[lower: upper].tolist(), - bottom5_logit_contributions=bottom5_contribution_to_logits.values[lower: upper].tolist(), + top5_token_ids=top5_contribution_to_logits.indices[ + lower:upper + ].tolist(), + top5_logit_contributions=top5_contribution_to_logits.values[ + lower:upper + ].tolist(), + bottom5_token_ids=bottom5_contribution_to_logits.indices[ + lower:upper + ].tolist(), + bottom5_logit_contributions=bottom5_contribution_to_logits.values[ + lower:upper + ].tolist(), ) g_total += len(indices) @@ -806,7 +919,6 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo t5 = time.time() - # ! Get all data for the middle column visualisations, i.e. the two histograms & the logit table nonzero_feat_acts = [] @@ -816,7 +928,6 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo logits_histogram_data = [] for feat in range(n_feats): - _logits = logits[feat] # Get data for logits histogram @@ -824,39 +935,43 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo logits_histogram_data.append(HistogramData(_logits, n_bins=40, tickmode="ints")) # Get data for logits table - top10_logits.append((TopK(_logits.topk(k=10, largest=False)), TopK(_logits.topk(k=10)))) + top10_logits.append( + (TopK(_logits.topk(k=10, largest=False)), TopK(_logits.topk(k=10))) + ) # Get data for feature activations histogram _feat_acts = feat_acts[..., feat] nonzero_feat_acts = _feat_acts[_feat_acts > 0] frac_nonzero.append(nonzero_feat_acts.numel() / _feat_acts.numel()) - frequencies_histogram_data.append(HistogramData(nonzero_feat_acts, n_bins=40, tickmode="ints")) + frequencies_histogram_data.append( + HistogramData(nonzero_feat_acts, n_bins=40, tickmode="ints") + ) t6 = time.time() - # ! Return the output, as a dict of FeatureData items vocab_dict = model.tokenizer.vocab - vocab_dict = {v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()} + vocab_dict = { + v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items() + } return_obj = { feat: FeatureData( - # For right-hand sequences sequence_data=sequence_data_list[i], - # For middle column (logits table, and both histograms) top10_logits=top10_logits[i], logits_histogram_data=logits_histogram_data[i], frequencies_histogram_data=frequencies_histogram_data[i], frac_nonzero=frac_nonzero[i], - # For left column, i.e. the 3 tables of size 3 neuron_alignment=(top3_neurons_aligned[i], pct_of_l1[i]), - neurons_correlated=(top_correlations_neurons[0][i], top_correlations_neurons[1][i]), - b_features_correlated=None,#(top_correlations_encoder_B[0][i], top_correlations_encoder_B[1][i]), - + neurons_correlated=( + top_correlations_neurons[0][i], + top_correlations_neurons[1][i], + ), + b_features_correlated=None, # (top_correlations_encoder_B[0][i], top_correlations_encoder_B[1][i]), # Other stuff (not containing data) vocab_dict=vocab_dict, buffer=buffer, @@ -867,41 +982,55 @@ def hook_fn_resid_post(resid_post: Float[Tensor, "batch seq d_model"], hook: Hoo for i, feat in enumerate(feature_idx) } - # ! If verbose, try to estimate time it will take to generate data for all features, plus storage space if verbose: - n_feats_total = encoder.cfg.d_sae # Get time total_time = t5 - t0 table = Table("Task", "Time", "Pct %", title="Time taken for each task") for task, _time in zip( - ["Setup code", "Fwd passes", "Concats", "Left-hand tables", "Right-hand sequences", "Middle column"], - [t1-t0, t2-t1, t3-t2, t4-t3, t5-t4, t6-t5] + [ + "Setup code", + "Fwd passes", + "Concats", + "Left-hand tables", + "Right-hand sequences", + "Middle column", + ], + [t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4, t6 - t5], ): frac = _time / total_time table.add_row(task, f"{_time:.2f}s", f"{frac:.1%}") rprint(table) - est = ((t3 - t0) + (n_feats_total / n_feats) * (t6 - t4) / 60) + est = (t3 - t0) + (n_feats_total / n_feats) * (t6 - t4) / 60 print(f"Estimated time for all {n_feats_total} features = {est:.0f} minutes\n") # Get filesizes, for different methods of saving batch_size = 50 if n_feats >= batch_size: - print(f"Estimated filesize of all {n_feats_total} features if saved in groups of batch_size, with save type...") - save_obj = {k: v for k, v in return_obj.items() if k in feature_idx[:batch_size]} + print( + f"Estimated filesize of all {n_feats_total} features if saved in groups of batch_size, with save type..." + ) + save_obj = { + k: v for k, v in return_obj.items() if k in feature_idx[:batch_size] + } filename = str(Path(__file__).parent.resolve() / "temp") for save_type in ["pkl", "gzip"]: t0 = time.time() - full_filename = FeatureData.save_batch(save_obj, filename=filename, save_type=save_type) + full_filename = FeatureData.save_batch( + save_obj, filename=filename, save_type=save_type + ) t1 = time.time() - loaded_obj = FeatureData.load_batch(filename, save_type=save_type, vocab_dict=vocab_dict) + loaded_obj = FeatureData.load_batch( + filename, save_type=save_type, vocab_dict=vocab_dict + ) t2 = time.time() filesize = os.path.getsize(full_filename) / 1e6 - print(f"{save_type:>5} = {filesize * n_feats_total / batch_size:>5.1f} MB, save time = {t1-t0:.3f}s, load time = {t2-t1:.3f}s") + print( + f"{save_type:>5} = {filesize * n_feats_total / batch_size:>5.1f} MB, save time = {t1-t0:.3f}s, load time = {t2-t1:.3f}s" + ) os.remove(full_filename) return return_obj - diff --git a/sae_analysis/visualizer/html/frequency_histogram.html b/sae_analysis/visualizer/html/frequency_histogram.html index 27baa464..2d5d17ef 100644 --- a/sae_analysis/visualizer/html/frequency_histogram.html +++ b/sae_analysis/visualizer/html/frequency_histogram.html @@ -2,4 +2,4 @@
- \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/hovertext_script.html b/sae_analysis/visualizer/html/hovertext_script.html index 70d682be..642da429 100644 --- a/sae_analysis/visualizer/html/hovertext_script.html +++ b/sae_analysis/visualizer/html/hovertext_script.html @@ -23,4 +23,4 @@ }); }); - \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/logit_table_template.html b/sae_analysis/visualizer/html/logit_table_template.html index fad88115..cba07816 100644 --- a/sae_analysis/visualizer/html/logit_table_template.html +++ b/sae_analysis/visualizer/html/logit_table_template.html @@ -38,4 +38,4 @@

POSITIVE LOGITS

- \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/logits_histogram.html b/sae_analysis/visualizer/html/logits_histogram.html index 9e0589ea..a88ee6b3 100644 --- a/sae_analysis/visualizer/html/logits_histogram.html +++ b/sae_analysis/visualizer/html/logits_histogram.html @@ -110,4 +110,4 @@ // }); - \ No newline at end of file + diff --git a/sae_analysis/visualizer/html/token_template.html b/sae_analysis/visualizer/html/token_template.html index 9048dea6..49b7a7c7 100644 --- a/sae_analysis/visualizer/html/token_template.html +++ b/sae_analysis/visualizer/html/token_template.html @@ -29,4 +29,4 @@ - \ No newline at end of file + diff --git a/sae_analysis/visualizer/html_fns.py b/sae_analysis/visualizer/html_fns.py index 5d3955af..3ece35ac 100644 --- a/sae_analysis/visualizer/html_fns.py +++ b/sae_analysis/visualizer/html_fns.py @@ -8,20 +8,22 @@ from sae_analysis.visualizer.utils_fns import to_str_tokens -''' +""" Key feature of these functions: the arguments should be descriptive of their role in the actual HTML visualisation. If the arguments are super arcane features of the model data, this is bad! -''' +""" ROOT_DIR = Path(__file__).parent CSS_DIR = Path(__file__).parent / "css" -CSS = "\n".join([ - (CSS_DIR / "general.css").read_text(), - (CSS_DIR / "sequences.css").read_text(), - (CSS_DIR / "tables.css").read_text(), -]) +CSS = "\n".join( + [ + (CSS_DIR / "general.css").read_text(), + (CSS_DIR / "sequences.css").read_text(), + (CSS_DIR / "tables.css").read_text(), + ] +) HTML_DIR = Path(__file__).parent / "html" HTML_TOKEN = (HTML_DIR / "token_template.html").read_text() @@ -32,35 +34,32 @@ HTML_HOVERTEXT_SCRIPT = (HTML_DIR / "hovertext_script.html").read_text() -BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list("bg_color_map", ["white", "darkorange"]) +BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list( + "bg_color_map", ["white", "darkorange"] +) def generate_tok_html( vocab_dict: dict, - this_token: str, underline_color: str, bg_color: str, is_bold: bool = False, - feat_act: float = 0.0, contribution_to_loss: float = 0.0, pos_ids: List[int] = [0, 0, 0, 0, 0], pos_val: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0], neg_ids: List[int] = [0, 0, 0, 0, 0], neg_val: List[float] = [0.0, 0.0, 0.0, 0.0, 0.0], - - ): - ''' + """ Creates a single sequence visualisation, by reading from the `token_template.html` file. Currently, a bunch of things are randomly chosen rather than actually calculated (we're going for proof of concept here). - ''' + """ html_output = ( - HTML_TOKEN - .replace("this_token", to_str_tokens(vocab_dict, this_token)) + HTML_TOKEN.replace("this_token", to_str_tokens(vocab_dict, this_token)) .replace("feat_activation", f"{feat_act:+.3f}") .replace("feature_ablation", f"{contribution_to_loss:+.3f}") .replace("font_weight", "bold" if is_bold else "normal") @@ -70,7 +69,7 @@ def generate_tok_html( # Figure out if the activations were zero on previous token, i.e. no predictions were affected is_empty = len(pos_ids) + len(neg_ids) == 0 - + # Get the string tokens pos_str = [to_str_tokens(vocab_dict, i) for i in pos_ids] neg_str = [to_str_tokens(vocab_dict, i) for i in neg_ids] @@ -80,31 +79,51 @@ def generate_tok_html( neg_str.extend([""] * 5) pos_val.extend([0.0] * 5) neg_val.extend([0.0] * 5) - + # Make all the substitutions - html_output = re.sub("pos_str_(\d)", lambda m: pos_str[int(m.group(1))].replace(" ", " "), html_output) - html_output = re.sub("neg_str_(\d)", lambda m: neg_str[int(m.group(1))].replace(" ", " "), html_output) - html_output = re.sub("pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output) - html_output = re.sub("neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output) + html_output = re.sub( + "pos_str_(\d)", + lambda m: pos_str[int(m.group(1))].replace(" ", " "), + html_output, + ) + html_output = re.sub( + "neg_str_(\d)", + lambda m: neg_str[int(m.group(1))].replace(" ", " "), + html_output, + ) + html_output = re.sub( + "pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output + ) + html_output = re.sub( + "neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output + ) # If the effect on loss is nothing (because feature isn't active), replace the HTML output with smth saying this if is_empty: html_output = ( - html_output - .replace('
', '