Skip to content

Commit

Permalink
Merge pull request #16 from chanind/flake-default-rules
Browse files Browse the repository at this point in the history
chore: enable full flake8 default rules list
  • Loading branch information
jbloomAus authored Feb 28, 2024
2 parents 496f7b4 + 19886e2 commit ad84706
Show file tree
Hide file tree
Showing 16 changed files with 51 additions and 72 deletions.
7 changes: 4 additions & 3 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
[flake8]
ignore = E203, E266, E501, W503
extend-ignore = E203, E266, E501, W503, E721, F722, E731
max-line-length = 79
max-complexity = 10
select = E9, F63, F7, F82
max-complexity = 25
extend-select = E9, F63, F7, F82
show-source = true
statistics = true
exclude = ./sae_training/geom_median/
6 changes: 1 addition & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ jobs:
- name: Install dependencies
run: poetry install --no-interaction
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
run: poetry run flake8 .
- name: black code formatting
run: poetry run black . --check
- name: isort linting
Expand Down
3 changes: 3 additions & 0 deletions sae_analysis/dashboard_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# flake8: noqa: E402
# TODO: are these sys.path.append calls really necessary?

import sys

sys.path.append("..")
Expand Down
26 changes: 12 additions & 14 deletions sae_analysis/visualizer/data_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

Arr = np.ndarray

from sae_analysis.visualizer.html_fns import (
CSS,
HTML_HOVERTEXT_SCRIPT,
Expand All @@ -39,6 +37,8 @@
to_str_tokens,
)

Arr = np.ndarray

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Expand Down Expand Up @@ -497,7 +497,7 @@ def __init__(self):
self.x2_sum = 0
self.y2_sum = 0

def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]): # noqa
def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]):
assert x.ndim == 2 and y.ndim == 2, "Both x and y should be 2D"
assert (
x.shape[-1] == y.shape[-1]
Expand All @@ -510,7 +510,7 @@ def update(self, x: Float[Tensor, "X N"], y: Float[Tensor, "Y N"]): # noqa
self.x2_sum += einops.reduce(x**2, "X N -> X", "sum")
self.y2_sum += einops.reduce(y**2, "Y N -> Y", "sum")

def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]: # noqa
def corrcoef(self) -> Tuple[Float[Tensor, "X Y"], Float[Tensor, "X Y"]]:
cossim_numer = self.xy_sum
cossim_denom = torch.sqrt(torch.outer(self.x2_sum, self.y2_sum)) + 1e-6
cossim = cossim_numer / cossim_denom
Expand Down Expand Up @@ -549,7 +549,7 @@ def get_feature_data(
hook_point: str,
hook_point_layer: int,
hook_point_head_index: Optional[int],
tokens: Int[Tensor, "batch seq"], # noqa
tokens: Int[Tensor, "batch seq"],
feature_idx: Union[int, List[int]],
max_batch_size: Optional[int] = None,
left_hand_k: int = 3,
Expand Down Expand Up @@ -624,10 +624,8 @@ def get_feature_data(

# ! Define hook function to perform feature ablation

def hook_fn_act_post(
act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint # noqa
): # noqa
"""
def hook_fn_act_post(act_post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint):
r"""
Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
- f_i are the feature activations
- d_i are the feature output directions
Expand Down Expand Up @@ -663,10 +661,9 @@ def hook_fn_act_post(
# )

def hook_fn_query(
hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint # noqa
hook_q: Float[Tensor, "batch seq n_head d_head"], hook: HookPoint
):
"""
r"""
Replace act_post with projection of query onto the resid by W_k^T.
Encoder has learned x^j \approx b + \sum_i f_i(x^j)d_i where:
- f_i are the feature activations
Expand Down Expand Up @@ -698,7 +695,7 @@ def hook_fn_query(
)

def hook_fn_resid_post(
resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint # noqa
resid_post: Float[Tensor, "batch seq d_model"], hook: HookPoint
):
"""
This hook function stores the residual activations, which we'll need later on to calculate the effect of feature ablation.
Expand Down Expand Up @@ -1020,7 +1017,8 @@ def hook_fn_resid_post(
save_obj, filename=filename, save_type=save_type
)
t1 = time.time()
loaded_obj = FeatureData.load_batch(
# TODO: is this doing anything? the result isn't read
FeatureData.load_batch(
filename, save_type=save_type, vocab_dict=vocab_dict
)
t2 = time.time()
Expand Down
12 changes: 6 additions & 6 deletions sae_analysis/visualizer/html_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,20 @@ def generate_tok_html(

# Make all the substitutions
html_output = re.sub(
"pos_str_(\d)",
r"pos_str_(\d)",
lambda m: pos_str[int(m.group(1))].replace(" ", " "),
html_output,
)
html_output = re.sub(
"neg_str_(\d)",
r"neg_str_(\d)",
lambda m: neg_str[int(m.group(1))].replace(" ", " "),
html_output,
)
html_output = re.sub(
"pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output
r"pos_val_(\d)", lambda m: f"{pos_val[int(m.group(1))]:+.3f}", html_output
)
html_output = re.sub(
"neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output
r"neg_val_(\d)", lambda m: f"{neg_val[int(m.group(1))]:+.3f}", html_output
)

# If the effect on loss is nothing (because feature isn't active), replace the HTML output with smth saying this
Expand Down Expand Up @@ -235,7 +235,7 @@ def generate_tables_html(
if myformat is None
else format(mylist[int(m.group(1))], myformat)
)
html_output = re.sub(letter + "(\d)", fn, html_output, count=3)
html_output = re.sub(letter + r"(\d)", fn, html_output, count=3)

html_output_2 = HTML_LOGIT_TABLES

Expand All @@ -258,7 +258,7 @@ def generate_tables_html(
fn = lambda m: format(mylist[int(m.group(1))], "+.2f")
elif letter == "C":
fn = lambda m: str(mylist[int(m.group(1))])
html_output_2 = re.sub(letter + "(\d)", fn, html_output_2, count=10)
html_output_2 = re.sub(letter + r"(\d)", fn, html_output_2, count=10)

return (html_output, html_output_2)

Expand Down
1 change: 0 additions & 1 deletion sae_analysis/visualizer/model_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.notebook as tqdm
from transformer_lens import utils

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
Expand Down
13 changes: 5 additions & 8 deletions sae_analysis/visualizer/utils_fns.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
import re
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import einops
import numpy as np
import torch
from eindex import eindex
from jaxtyping import Float, Int
from torch import Tensor
from transformer_lens import HookedTransformer

Arr = np.ndarray


def k_largest_indices(
x: Float[Tensor, "rows cols"], # noqa
x: Float[Tensor, "rows cols"],
k: int,
largest: bool = True,
buffer: Tuple[int, int] = (5, 5),
) -> Int[Tensor, "k 2"]: # noqa
) -> Int[Tensor, "k 2"]:
"""w
Given a 2D array, returns the indices of the top or bottom `k` elements.
Expand All @@ -40,11 +37,11 @@ def sample_unique_indices(large_number, small_number):


def random_range_indices(
x: Float[Tensor, "batch seq"], # noqa
x: Float[Tensor, "batch seq"],
bounds: Tuple[float, float],
k: int,
buffer: Tuple[int, int] = (5, 5),
) -> Int[Tensor, "k 2"]: # noqa
) -> Int[Tensor, "k 2"]:
"""
Given a 2D array, returns the indices of `k` elements whose values are in the range `bounds`.
Will return fewer than `k` values if there aren't enough values in the range.
Expand Down
1 change: 0 additions & 1 deletion sae_training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformer_lens import HookedTransformer


Expand Down
3 changes: 0 additions & 3 deletions sae_training/lm_runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os

import torch
import wandb

# from sae_training.activation_store import ActivationStore
Expand Down
7 changes: 1 addition & 6 deletions sae_training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,10 @@
import gzip
import os
import pickle
from functools import partial

import einops
import torch
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from tqdm import tqdm
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint

from sae_training.geom_median.src.geom_median.torch import compute_geometric_median
Expand Down
1 change: 0 additions & 1 deletion sae_training/toy_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import einops
import torch
import wandb
from transformer_lens import HookedTransformer

from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.toy_models import Config as ToyConfig
Expand Down
31 changes: 14 additions & 17 deletions sae_training/toy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@

import einops
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch as t
from IPython.display import clear_output
from jaxtyping import Float, Int
from jaxtyping import Float
from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import Slider # , Button
from plotly.subplots import make_subplots
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm
Expand Down Expand Up @@ -51,8 +48,8 @@ class Config:


class Model(nn.Module):
W: Float[Tensor, "n_instances n_hidden n_features"] # noqa
b_final: Float[Tensor, "n_instances n_features"] # noqa
W: Float[Tensor, "n_instances n_hidden n_features"]
b_final: Float[Tensor, "n_instances n_features"]
# Our linear map is x -> ReLU(W.T @ W @ x + b_final)

def __init__(
Expand Down Expand Up @@ -89,8 +86,8 @@ def __init__(
self.to(device)

def forward(
self, features: Float[Tensor, "... instances features"] # noqa
) -> Float[Tensor, "... instances features"]: # noqa
self, features: Float[Tensor, "... instances features"]
) -> Float[Tensor, "... instances features"]:
hidden = einops.einsum(
features,
self.W,
Expand Down Expand Up @@ -119,7 +116,7 @@ def forward(

def generate_correlated_features(
self, batch_size, n_correlated_pairs
) -> Float[Tensor, "batch_size instances features"]: # noqa
) -> Float[Tensor, "batch_size instances features"]:
"""
Generates a batch of correlated features.
Each output[i, j, 2k] and output[i, j, 2k + 1] are correlated, i.e. one is present iff the other is present.
Expand All @@ -141,7 +138,7 @@ def generate_correlated_features(

def generate_anticorrelated_features(
self, batch_size, n_anticorrelated_pairs
) -> Float[Tensor, "batch_size instances features"]: # noqa
) -> Float[Tensor, "batch_size instances features"]:
"""
Generates a batch of anti-correlated features.
Each output[i, j, 2k] and output[i, j, 2k + 1] are anti-correlated, i.e. one is present iff the other is absent.
Expand Down Expand Up @@ -178,7 +175,7 @@ def generate_anticorrelated_features(

def generate_uncorrelated_features(
self, batch_size, n_uncorrelated
) -> Float[Tensor, "batch_size instances features"]: # noqa
) -> Float[Tensor, "batch_size instances features"]:
"""
Generates a batch of uncorrelated features.
"""
Expand All @@ -193,7 +190,7 @@ def generate_uncorrelated_features(

def generate_batch(
self, batch_size
) -> Float[Tensor, "batch_size instances features"]: # noqa
) -> Float[Tensor, "batch_size instances features"]:
"""
Generates a batch of data, with optional correlated & anticorrelated features.
"""
Expand Down Expand Up @@ -222,9 +219,9 @@ def generate_batch(

def calculate_loss(
self,
out: Float[Tensor, "batch instances features"], # noqa
batch: Float[Tensor, "batch instances features"], # noqa
) -> Float[Tensor, ""]: # noqa
out: Float[Tensor, "batch instances features"],
batch: Float[Tensor, "batch instances features"],
) -> Float[Tensor, ""]:
"""
Calculates the loss for a given batch, using this loss described in the Toy Models paper:
Expand Down Expand Up @@ -278,7 +275,7 @@ def optimize(


def plot_features_in_2d(
values: Float[Tensor, "timesteps instances d_hidden feats"], # noqa
values: Float[Tensor, "timesteps instances d_hidden feats"],
colors=None, # shape [timesteps instances feats]
title: Optional[str] = None,
subplot_titles: Optional[List[str]] = None,
Expand Down Expand Up @@ -431,7 +428,7 @@ def play(event):


def parse_colors_for_superposition_plot(
colors: Optional[Union[Tuple[int, int], Float[Tensor, "instances feats"]]], # noqa
colors: Optional[Union[Tuple[int, int], Float[Tensor, "instances feats"]]],
n_instances: int,
n_feats: int,
) -> List[List[str]]:
Expand Down
3 changes: 0 additions & 3 deletions sae_training/train_sae_on_toy_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import einops
import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from sae_training.sparse_autoencoder import SparseAutoencoder
from sae_training.toy_models import Model as ToyModel


def train_toy_sae(
Expand Down Expand Up @@ -65,7 +63,6 @@ def train_toy_sae(
)

l0 = (feature_acts > 0).float().sum(-1).mean()
current_learning_rate = optimizer.param_groups[0]["lr"]
l2_norm = torch.norm(feature_acts, dim=1).mean()

l2_norm_in = torch.norm(batch, dim=-1)
Expand Down
2 changes: 2 additions & 0 deletions scripts/generate_dashboards.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# flake8: noqa: E402
# TODO: are these sys.path.append calls really necessary?
import sys

sys.path.append("..")
Expand Down
1 change: 0 additions & 1 deletion tests/benchmark/test_toy_model_sae_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
import torch

from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner
Expand Down
Loading

0 comments on commit ad84706

Please sign in to comment.