diff --git a/.flake8 b/.flake8
index 138a4973..e7169d49 100644
--- a/.flake8
+++ b/.flake8
@@ -1,7 +1,8 @@
 [flake8]
-ignore = E203, E266, E501, W503
+extend-ignore = E203, E266, E501, W503, E721, F722, E731
 max-line-length = 79
-max-complexity = 10
-select = E9, F63, F7, F82
+max-complexity = 25
+extend-select = E9, F63, F7, F82
 show-source = true
 statistics = true
+exclude = ./sae_training/geom_median/
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 748aaef8..7316b376 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -47,11 +47,7 @@ jobs:
       - name: Install dependencies
         run: poetry install --no-interaction
       - name: Lint with flake8
-        run: |
-          # stop the build if there are Python syntax errors or undefined names
-          poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        run: poetry run flake8 .
       - name: black code formatting
         run: poetry run black . --check
       - name: isort linting
diff --git a/sae_analysis/dashboard_runner.py b/sae_analysis/dashboard_runner.py
index 815c6b14..4a3815ab 100644
--- a/sae_analysis/dashboard_runner.py
+++ b/sae_analysis/dashboard_runner.py
@@ -1,3 +1,6 @@
+# flake8: noqa: E402
+# TODO: are these sys.path.append calls really necessary?
+
 import sys
 
 sys.path.append("..")
diff --git a/sae_analysis/visualizer/data_fns.py b/sae_analysis/visualizer/data_fns.py
index effe6f3f..7deb5d94 100644
--- a/sae_analysis/visualizer/data_fns.py
+++ b/sae_analysis/visualizer/data_fns.py
@@ -20,8 +20,6 @@
 from transformer_lens import HookedTransformer, utils
 from transformer_lens.hook_points import HookPoint
 
-Arr = np.ndarray
-
 from sae_analysis.visualizer.html_fns import (
     CSS,
     HTML_HOVERTEXT_SCRIPT,
@@ -39,6 +37,8 @@
     to_str_tokens,
 )
 
+Arr = np.ndarray
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -497,7 +497,7 @@ def __init__(self):
         self.x2_sum = 0
         self.y2_sum = 0
 
-    def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):  # noqa
+    def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):
         assert x.ndim == 2 and y.ndim == 2, "Both x and y should be 2D"
         assert (
             x.shape[-1] == y.shape[-1]
@@ -510,7 +510,7 @@ def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):  # noqa
         self.x2_sum += einops.reduce(x**2, "X N -> X", "sum")
         self.y2_sum += einops.reduce(y**2, "Y N -> Y", "sum")
 
-    def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]:  # noqa
+    def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]:
         cossim_numer = self.xy_sum
         cossim_denom = torch.sqrt(torch.outer(self.x2_sum, self.y2_sum)) + 1e-6
         cossim = cossim_numer / cossim_denom
@@ -549,7 +549,7 @@ def get_feature_data(
     hook_point: str,
     hook_point_layer: int,
     hook_point_head_index: Optional[int],
-    tokens: Int[Tensor, "batch seq"],  # noqa
+    tokens: Int[Tensor, "batch seq"],
     feature_idx: Union[int, List[int]],
     max_batch_size: Optional[int] = None,
     left_hand_k: int = 3,
@@ -624,10 +624,8 @@ def get_feature_data(
 
     # ! Define hook function to perform feature ablation
 
-    def hook_fn_act_post(
-        act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint  # noqa
-    ):  # noqa
-        """
+    def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint):
+        r"""
         Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
         - f_i are the feature activations
         - d_i are the feature output directions
@@ -663,10 +661,9 @@ def hook_fn_act_post(
     # )
 
     def hook_fn_query(
-        hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint  # noqa
+        hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint
     ):
-        """
-
+        r"""
         Replace act_post with projection of query onto the resid by W_k^T.
         Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
         - f_i are the feature activations
@@ -698,7 +695,7 @@ def hook_fn_query(
         )
 
     def hook_fn_resid_post(
-        resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint  # noqa
+        resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint
     ):
         """
         This hook function stores the residual activations, which we'll need later on to calculate the effect of feature ablation.
@@ -1020,7 +1017,8 @@ def hook_fn_resid_post(
             save_obj, filename=filename, save_type=save_type
         )
         t1 = time.time()
-        loaded_obj = FeatureData.load_batch(
+        # TODO: is this doing anything? the result isn't read
+        FeatureData.load_batch(
             filename, save_type=save_type, vocab_dict=vocab_dict
         )
         t2 = time.time()
diff --git a/sae_analysis/visualizer/html_fns.py b/sae_analysis/visualizer/html_fns.py
index 9b0e546d..2b524dcc 100644
--- a/sae_analysis/visualizer/html_fns.py
+++ b/sae_analysis/visualizer/html_fns.py
@@ -82,20 +82,20 @@ def generate_tok_html(
 
     # Make all the substitutions
     html_output = re.sub(
-        "pos_str_(\d)",
+        r"pos_str_(\d)",
         lambda m: pos_str[int(m.group(1))].replace(" ", "&nbsp;"),
         html_output,
     )
     html_output = re.sub(
-        "neg_str_(\d)",
+        r"neg_str_(\d)",
         lambda m: neg_str[int(m.group(1))].replace(" ", "&nbsp;"),
         html_output,
     )
     html_output = re.sub(
-        "pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output
+        r"pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output
     )
     html_output = re.sub(
-        "neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output
+        r"neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output
     )
 
     # If the effect on loss is nothing (because feature isn't active), replace the HTML output with smth saying this
@@ -235,7 +235,7 @@ def generate_tables_html(
             if myformat is None
             else format(mylist[int(m.group(1))], myformat)
         )
-        html_output = re.sub(letter + "(\d)", fn, html_output, count=3)
+        html_output = re.sub(letter + r"(\d)", fn, html_output, count=3)
 
     html_output_2 = HTML_LOGIT_TABLES
@@ -258,7 +258,7 @@ def generate_tables_html(
             fn = lambda m: format(mylist[int(m.group(1))], "+.2f")
         elif letter == "C":
             fn = lambda m: str(mylist[int(m.group(1))])
-        html_output_2 = re.sub(letter + "(\d)", fn, html_output_2, count=10)
+        html_output_2 = re.sub(letter + r"(\d)", fn, html_output_2, count=10)
 
     return (html_output, html_output_2)
diff --git a/sae_analysis/visualizer/model_fns.py b/sae_analysis/visualizer/model_fns.py
index 3a4d6879..ebf4b570 100644
--- a/sae_analysis/visualizer/model_fns.py
+++ b/sae_analysis/visualizer/model_fns.py
@@ -4,7 +4,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import tqdm.notebook as tqdm
 from transformer_lens import utils
 
 DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
diff --git a/sae_analysis/visualizer/utils_fns.py b/sae_analysis/visualizer/utils_fns.py
index 51ab8980..7b0c59b6 100644
--- a/sae_analysis/visualizer/utils_fns.py
+++ b/sae_analysis/visualizer/utils_fns.py
@@ -1,23 +1,20 @@
 import re
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
-import einops
 import numpy as np
 import torch
-from eindex import eindex
 from jaxtyping import Float, Int
 from torch import Tensor
-from transformer_lens import HookedTransformer
 
 Arr = np.ndarray
 
 
 def k_largest_indices(
-    x: Float[Tensor, "rows cols"],  # noqa
+    x: Float[Tensor, "rows cols"],
     k: int,
     largest: bool = True,
     buffer: Tuple[int, int] = (5, 5),
-) -> Int[Tensor, "k 2"]:  # noqa
+) -> Int[Tensor, "k 2"]:
     """w
 
     Given a 2D array, returns the indices of the top or bottom `k` elements.
@@ -40,11 +37,11 @@ def sample_unique_indices(large_number, small_number):
 
 
 def random_range_indices(
-    x: Float[Tensor, "batch seq"],  # noqa
+    x: Float[Tensor, "batch seq"],
     bounds: Tuple[float, float],
     k: int,
     buffer: Tuple[int, int] = (5, 5),
-) -> Int[Tensor, "k 2"]:  # noqa
+) -> Int[Tensor, "k 2"]:
     """
     Given a 2D array, returns the indices of `k` elements whose values are in the range `bounds`.
     Will return fewer than `k` values if there aren't enough values in the range.
diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py
index 9bae6824..103aae19 100644
--- a/sae_training/activations_store.py
+++ b/sae_training/activations_store.py
@@ -3,7 +3,6 @@
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from tqdm import tqdm
 from transformer_lens import HookedTransformer
diff --git a/sae_training/lm_runner.py b/sae_training/lm_runner.py
index ec038de8..9b51e960 100644
--- a/sae_training/lm_runner.py
+++ b/sae_training/lm_runner.py
@@ -1,6 +1,3 @@
-import os
-
-import torch
 import wandb
 
 # from sae_training.activation_store import ActivationStore
diff --git a/sae_training/sparse_autoencoder.py b/sae_training/sparse_autoencoder.py
index 23abbf05..7d5c3bf3 100644
--- a/sae_training/sparse_autoencoder.py
+++ b/sae_training/sparse_autoencoder.py
@@ -5,15 +5,10 @@
 import gzip
 import os
 import pickle
-from functools import partial
 
 import einops
 import torch
-import torch.nn.functional as F
-from jaxtyping import Float
-from torch import Tensor, nn
-from torch.distributions.categorical import Categorical
-from tqdm import tqdm
+from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 
 from sae_training.geom_median.src.geom_median.torch import compute_geometric_median
diff --git a/sae_training/toy_model_runner.py b/sae_training/toy_model_runner.py
index 7726a24f..67aa267a 100644
--- a/sae_training/toy_model_runner.py
+++ b/sae_training/toy_model_runner.py
@@ -3,7 +3,6 @@
 import einops
 import torch
 import wandb
-from transformer_lens import HookedTransformer
 
 from sae_training.sparse_autoencoder import SparseAutoencoder
 from sae_training.toy_models import Config as ToyConfig
diff --git a/sae_training/toy_models.py b/sae_training/toy_models.py
index 32671de4..384954f7 100644
--- a/sae_training/toy_models.py
+++ b/sae_training/toy_models.py
@@ -10,15 +10,12 @@
 
 import einops
 import numpy as np
-import plotly.express as px
-import plotly.graph_objects as go
 import torch as t
 from IPython.display import clear_output
-from jaxtyping import Float, Int
+from jaxtyping import Float
 from matplotlib import pyplot as plt
 from matplotlib.animation import FuncAnimation
 from matplotlib.widgets import Slider  # , Button
-from plotly.subplots import make_subplots
 from torch import Tensor, nn
 from torch.nn import functional as F
 from tqdm import tqdm
@@ -51,8 +48,8 @@ class Config:
 
 
 class Model(nn.Module):
-    W: Float[Tensor, "n_instances n_hidden n_features"]  # noqa
-    b_final: Float[Tensor, "n_instances n_features"]  # noqa
+    W: Float[Tensor, "n_instances n_hidden n_features"]
+    b_final: Float[Tensor, "n_instances n_features"]
     # Our linear map is x -> ReLU(W.T @ W @ x + b_final)
 
     def __init__(
@@ -89,8 +86,8 @@ def __init__(
         self.to(device)
 
     def forward(
-        self, features: Float[Tensor, "... instances features"]  # noqa
-    ) -> Float[Tensor, "... instances features"]:  # noqa
+        self, features: Float[Tensor, "... instances features"]
+    ) -> Float[Tensor, "... instances features"]:
         hidden = einops.einsum(
             features,
             self.W,
@@ -119,7 +116,7 @@ def forward(
 
     def generate_correlated_features(
         self, batch_size, n_correlated_pairs
-    ) -> Float[Tensor, "batch_size instances features"]:  # noqa
+    ) -> Float[Tensor, "batch_size instances features"]:
         """
         Generates a batch of correlated features.
         Each output[i, j, 2k] and output[i, j, 2k + 1] are correlated, i.e. one is present iff the other is present.
@@ -141,7 +138,7 @@ def generate_correlated_features(
 
     def generate_anticorrelated_features(
         self, batch_size, n_anticorrelated_pairs
-    ) -> Float[Tensor, "batch_size instances features"]:  # noqa
+    ) -> Float[Tensor, "batch_size instances features"]:
         """
         Generates a batch of anti-correlated features.
         Each output[i, j, 2k] and output[i, j, 2k + 1] are anti-correlated, i.e. one is present iff the other is absent.
@@ -178,7 +175,7 @@ def generate_anticorrelated_features(
 
     def generate_uncorrelated_features(
         self, batch_size, n_uncorrelated
-    ) -> Float[Tensor, "batch_size instances features"]:  # noqa
+    ) -> Float[Tensor, "batch_size instances features"]:
         """
         Generates a batch of uncorrelated features.
         """
@@ -193,7 +190,7 @@
 
     def generate_batch(
         self, batch_size
-    ) -> Float[Tensor, "batch_size instances features"]:  # noqa
+    ) -> Float[Tensor, "batch_size instances features"]:
         """
         Generates a batch of data, with optional correlated & anticorrelated features.
""" @@ -222,9 +219,9 @@ def generate_batch( def calculate_loss( self, - out: Float[Tensor, "batch instances features"], # noqa - batch: Float[Tensor, "batch instances features"], # noqa - ) -> Float[Tensor, ""]: # noqa + out: Float[Tensor, "batch instances features"], + batch: Float[Tensor, "batch instances features"], + ) -> Float[Tensor, ""]: """ Calculates the loss for a given batch, using this loss described in the Toy Models paper: @@ -278,7 +275,7 @@ def optimize( def plot_features_in_2d( - values: Float[Tensor, "timesteps instances d_hidden feats"], # noqa + values: Float[Tensor, "timesteps instances d_hidden feats"], colors=None, # shape [timesteps instances feats] title: Optional[str] = None, subplot_titles: Optional[List[str]] = None, @@ -431,7 +428,7 @@ def play(event): def parse_colors_for_superposition_plot( - colors: Optional[Union[Tuple[int, int], Float[Tensor, "instances feats"]]], # noqa + colors: Optional[Union[Tuple[int, int], Float[Tensor, "instances feats"]]], n_instances: int, n_feats: int, ) -> List[List[str]]: diff --git a/sae_training/train_sae_on_toy_model.py b/sae_training/train_sae_on_toy_model.py index 5554d4db..c43b75c4 100644 --- a/sae_training/train_sae_on_toy_model.py +++ b/sae_training/train_sae_on_toy_model.py @@ -1,11 +1,9 @@ -import einops import torch import wandb from torch.utils.data import DataLoader from tqdm import tqdm from sae_training.sparse_autoencoder import SparseAutoencoder -from sae_training.toy_models import Model as ToyModel def train_toy_sae( @@ -65,7 +63,6 @@ def train_toy_sae( ) l0 = (feature_acts > 0).float().sum(-1).mean() - current_learning_rate = optimizer.param_groups[0]["lr"] l2_norm = torch.norm(feature_acts, dim=1).mean() l2_norm_in = torch.norm(batch, dim=-1) diff --git a/scripts/generate_dashboards.py b/scripts/generate_dashboards.py index 7da5dc02..bbfa0af7 100644 --- a/scripts/generate_dashboards.py +++ b/scripts/generate_dashboards.py @@ -1,3 +1,5 @@ +# flake8: noqa: E402 +# TODO: are these sys.path.append calls really necessary? 
 import sys
 
 sys.path.append("..")
diff --git a/tests/benchmark/test_toy_model_sae_runner.py b/tests/benchmark/test_toy_model_sae_runner.py
index 811ac9c4..ed904d9f 100644
--- a/tests/benchmark/test_toy_model_sae_runner.py
+++ b/tests/benchmark/test_toy_model_sae_runner.py
@@ -1,4 +1,3 @@
-import pytest
 import torch
 
 from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner
diff --git a/tests/unit/test_sparse_autoencoder.py b/tests/unit/test_sparse_autoencoder.py
index 63422b09..e30b2103 100644
--- a/tests/unit/test_sparse_autoencoder.py
+++ b/tests/unit/test_sparse_autoencoder.py
@@ -92,7 +92,7 @@ def test_sparse_autoencoder_init(cfg):
 def test_save_model(cfg):
     with tempfile.TemporaryDirectory() as tmpdirname:
         # assert file does not exist
-        assert os.path.exists(tmpdirname + "/test.pt") == False
+        assert not os.path.exists(tmpdirname + "/test.pt")
 
         sparse_autoencoder = SparseAutoencoder(cfg)
         sparse_autoencoder.save_model(tmpdirname + "/test.pt")
@@ -120,7 +120,7 @@ def test_save_model(cfg):
 def test_load_from_pretrained_pt(cfg):
     with tempfile.TemporaryDirectory() as tmpdirname:
         # assert file does not exist
-        assert os.path.exists(tmpdirname + "/test.pt") == False
+        assert not os.path.exists(tmpdirname + "/test.pt")
 
         sparse_autoencoder = SparseAutoencoder(cfg)
         sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
@@ -152,7 +152,7 @@ def test_load_from_pretrained_pt(cfg):
 def test_load_from_pretrained_pkl_gz(cfg):
     with tempfile.TemporaryDirectory() as tmpdirname:
         # assert file does not exist
-        assert os.path.exists(tmpdirname + "/test.pkl.gz") == False
+        assert not os.path.exists(tmpdirname + "/test.pkl.gz")
 
         sparse_autoencoder = SparseAutoencoder(cfg)
         sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()