Skip to content

Commit

Permalink
Merge pull request #13 from chanind/poetry
Browse files Browse the repository at this point in the history
chore: using poetry for dependency management
  • Loading branch information
jbloomAus authored Feb 26, 2024
2 parents 3727b5d + 465e003 commit 496f7b4
Show file tree
Hide file tree
Showing 29 changed files with 122 additions and 69 deletions.
16 changes: 9 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,20 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'
- name: Install Poetry
uses: snok/install-poetry@v1
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
run: poetry install --no-interaction
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: black code formatting
run: poetry run black . --check
- name: isort linting
run: poetry run isort . --check-only --diff
- name: Run Unit Tests
run: |
make unit-test
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ ipython_config.py
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
Expand Down
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ This codebase contains training scripts and analysis code for Sparse AutoEncoder

## Set Up

```
conda create --name mats_sae_training python=3.11 -y
conda activate mats_sae_training
pip install -r requirements.txt
This project uses [Poetry](https://python-poetry.org/) for dependency management. Ensure Poetry is installed, then to install the dependencies, run:

```
poetry install
```

## Background
Expand Down
9 changes: 7 additions & 2 deletions makefile
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
format:
poetry run black .
poetry run isort .

check-format:
poetry run flake8 .
poetry run black --check .
poetry run isort --check-only --diff .


test:
make unit-test
make acceptance-test

unit-test:
pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit
poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit

acceptance-test:
pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
36 changes: 36 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[tool.poetry]
name = "mats_sae_training"
version = "0.1.0"
description = "Training Sparse Autoencoders (SAEs)"
authors = ["Joseph Bloom"]
readme = "README.md"
packages = [{include = "sae_analysis"}, {include = "sae_training"}]

[tool.poetry.dependencies]
python = "^3.10"
transformer-lens = "^1.14.0"
transformers = "^4.38.1"
jupyter = "^1.0.0"
plotly = "^5.19.0"
plotly-express = "^0.4.1"
nbformat = "^5.9.2"
ipykernel = "^6.29.2"
matplotlib = "^3.8.3"
matplotlib-inline = "^0.1.6"
eindex = {git = "https://github.com/callummcdougall/eindex.git"}


[tool.poetry.group.dev.dependencies]
black = "^24.2.0"
pytest = "^8.0.2"
pytest-cov = "^4.1.0"
pre-commit = "^3.6.2"
flake8 = "^7.0.0"
isort = "^5.13.2"

[tool.isort]
profile = "black"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ nbformat==5.9.2
ipykernel==6.27.1
matplotlib==3.8.2
matplotlib-inline==0.1.6
pylint==3.0.2
flake8==7.0.0
isort==5.13.2
black==23.11.0
pytest==7.4.3
pytest-cov==4.1.0
Expand Down
6 changes: 2 additions & 4 deletions sae_analysis/dashboard_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import plotly
import plotly.express as px
import torch
import wandb
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

import wandb
from sae_analysis.visualizer.data_fns import get_feature_data
from sae_training.utils import LMSparseAutoencoderSessionloader

Expand Down Expand Up @@ -148,9 +148,7 @@ def init_sae_session(self):
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)

def get_tokens(
self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6
):
def get_tokens(self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6):
"""
Get the tokens needed for dashboard generation.
"""
Expand Down
5 changes: 1 addition & 4 deletions sae_analysis/visualizer/data_fns.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import gzip
import json
import os
import pickle
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Union

import einops
import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions sae_analysis/visualizer/html_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def generate_tables_html(
],
[None, "+.2f", ".1%", None, "+.2f", "+.2f", None, "+.2f", "+.2f"],
):
fn = (
lambda m: str(mylist[int(m.group(1))])
fn = lambda m: (
str(mylist[int(m.group(1))])
if myformat is None
else format(mylist[int(m.group(1))], myformat)
)
Expand Down
8 changes: 4 additions & 4 deletions sae_analysis/visualizer/model_fns.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from transformer_lens import utils
import torch
import pprint
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.notebook as tqdm
from dataclasses import dataclass

from transformer_lens import utils

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

Expand Down
12 changes: 6 additions & 6 deletions sae_training/activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def get_buffer(self, n_batches_in_buffer):
taking_subset_of_file = True

# Add it to the buffer
new_buffer[
n_tokens_filled : n_tokens_filled + activations.shape[0]
] = activations
new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0]] = (
activations
)

# Update counters
n_tokens_filled += activations.shape[0]
Expand All @@ -235,9 +235,9 @@ def get_buffer(self, n_batches_in_buffer):
for refill_batch_idx_start in refill_iterator:
refill_batch_tokens = self.get_batch_tokens()
refill_activations = self.get_activations(refill_batch_tokens)
new_buffer[
refill_batch_idx_start : refill_batch_idx_start + batch_size
] = refill_activations
new_buffer[refill_batch_idx_start : refill_batch_idx_start + batch_size] = (
refill_activations
)

# pbar.update(1)

Expand Down
2 changes: 1 addition & 1 deletion sae_training/cache_activations_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import os

import torch
from transformer_lens import HookedTransformer
from tqdm import tqdm
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from sae_training.config import CacheActivationsRunnerConfig
Expand Down
11 changes: 6 additions & 5 deletions sae_training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Optional

import torch

import wandb


Expand All @@ -22,9 +21,9 @@ class RunnerConfig(ABC):
is_dataset_tokenized: bool = True
context_size: int = 128
use_cached_activations: bool = False
cached_activations_path: Optional[
str
] = None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
cached_activations_path: Optional[str] = (
None # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
)

# SAE Parameters
d_in: int = 512
Expand Down Expand Up @@ -61,7 +60,9 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
# Training Parameters
l1_coefficient: float = 1e-3
lr: float = 3e-4
lr_scheduler_name: str = "constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
lr_scheduler_name: str = (
"constantwithwarmup" # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
)
lr_warm_up_steps: int = 500
train_batch_size: int = 4096

Expand Down
7 changes: 5 additions & 2 deletions sae_training/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import pandas as pd
import torch
import wandb
from tqdm import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_act_name

import wandb
from sae_training.activations_store import ActivationsStore
from sae_training.sparse_autoencoder import SparseAutoencoder

Expand All @@ -27,7 +27,10 @@ def run_evals(

# Get Reconstruction Score
losses_df = recons_loss_batched(
sparse_autoencoder, model, activation_store, n_batches = 10,
sparse_autoencoder,
model,
activation_store,
n_batches=10,
)

recons_score = losses_df["score"].mean()
Expand Down
2 changes: 1 addition & 1 deletion sae_training/geom_median/src/geom_median/numpy/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np

from . import utils
from .weiszfeld_array import geometric_median_array, geometric_median_per_component
from .weiszfeld_list_of_array import geometric_median_list_of_array
from . import utils


def compute_geometric_median(
Expand Down
1 change: 1 addition & 0 deletions sae_training/geom_median/src/geom_median/numpy/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import zip_longest

import numpy as np


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from types import SimpleNamespace

import numpy as np


def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
Expand Down Expand Up @@ -36,9 +37,11 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):

return SimpleNamespace(
median=median,
termination="function value converged within tolerance"
if early_termination
else "maximum iterations reached",
termination=(
"function value converged within tolerance"
if early_termination
else "maximum iterations reached"
),
logs=logs,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
from types import SimpleNamespace

import numpy as np


def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
"""
Expand Down Expand Up @@ -38,9 +39,11 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=

return SimpleNamespace(
median=median,
termination="function value converged within tolerance"
if early_termination
else "maximum iterations reached",
termination=(
"function value converged within tolerance"
if early_termination
else "maximum iterations reached"
),
logs=logs,
)

Expand Down
2 changes: 1 addition & 1 deletion sae_training/geom_median/src/geom_median/torch/main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch

from . import utils
from .weiszfeld_array import geometric_median_array, geometric_median_per_component
from .weiszfeld_list_of_array import geometric_median_list_of_array
from . import utils


def compute_geometric_median(
Expand Down
1 change: 1 addition & 0 deletions sae_training/geom_median/src/geom_median/torch/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import zip_longest

import torch


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
return SimpleNamespace(
median=median,
new_weights=new_weights,
termination="function value converged within tolerance"
if early_termination
else "maximum iterations reached",
termination=(
"function value converged within tolerance"
if early_termination
else "maximum iterations reached"
),
logs=logs,
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from types import SimpleNamespace

import numpy as np
import torch
from types import SimpleNamespace


def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
Expand Down Expand Up @@ -43,9 +44,11 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
return SimpleNamespace(
median=median,
new_weights=new_weights,
termination="function value converged within tolerance"
if early_termination
else "maximum iterations reached",
termination=(
"function value converged within tolerance"
if early_termination
else "maximum iterations reached"
),
logs=logs,
)

Expand Down
Loading

0 comments on commit 496f7b4

Please sign in to comment.