update sae loading code
jbloom-md committed Apr 14, 2024
1 parent 96b1e12 commit 356a8ef
Showing 1 changed file with 56 additions and 16 deletions.
72 changes: 56 additions & 16 deletions sae_lens/toolkit/pretrained_saes.py
@@ -1,5 +1,6 @@
 import json
 import os
+from typing import Optional, Tuple
 
 import torch
 from huggingface_hub import hf_hub_download, list_files_info
@@ -17,36 +18,75 @@ def load_sparsity(path: str) -> torch.Tensor:
     return sparsity
 
 
-def get_gpt2_res_jb_saes() -> (
-    tuple[dict[str, SparseAutoencoder], dict[str, torch.Tensor]]
-):
+def download_sae_from_hf(
+    repo_id: str = "jbloom/GPT2-Small-SAEs-Reformatted",
+    folder_name: str = "blocks.0.hook_resid_pre",
+    force_download: bool = False,
+) -> Tuple[str, str, Optional[str]]:
+
+    FILENAME = f"{folder_name}/cfg.json"
+    cfg_path = hf_hub_download(
+        repo_id=repo_id, filename=FILENAME, force_download=force_download
+    )
+
+    FILENAME = f"{folder_name}/sae_weights.safetensors"
+    sae_path = hf_hub_download(
+        repo_id=repo_id, filename=FILENAME, force_download=force_download
+    )
+
+    try:
+        FILENAME = f"{folder_name}/sparsity.safetensors"
+        sparsity_path = hf_hub_download(
+            repo_id=repo_id, filename=FILENAME, force_download=force_download
+        )
+    except:  # noqa
+        sparsity_path = None
+
+    return cfg_path, sae_path, sparsity_path
+
+
+def load_sae_from_local_path(path: str) -> Tuple[SparseAutoencoder, torch.Tensor]:
+    sae = SparseAutoencoder.load_from_pretrained(path)
+    sparsity = load_sparsity(path)
+    return sae, sparsity
+
+
+def get_gpt2_res_jb_saes(
+    hook_point: Optional[str] = None,
+    device: str = "cpu",
+) -> tuple[dict[str, SparseAutoencoder], dict[str, torch.Tensor]]:
"""
Download the sparse autoencoders for the GPT2-Small model with residual connections
from the repository of jbloom. You can specify a hook_point to download only one
of the sparse autoencoders if desired.
"""

GPT2_SMALL_RESIDUAL_SAES_REPO_ID = "jbloom/GPT2-Small-SAEs-Reformatted"
GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS = [
f"blocks.{layer}.hook_resid_pre" for layer in range(12)
] + ["blocks.11.hook_resid_post"]

if hook_point is not None:
assert hook_point in GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS, (
f"hook_point must be one of {GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS}"
f"but got {hook_point}"
)
GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS = [hook_point]

saes = {}
sparsities = {}
for hook_point in tqdm(GPT2_SMALL_RESIDUAL_SAES_HOOK_POINTS):
         # download the files required:
-        FILENAME = f"{hook_point}/cfg.json"
-        hf_hub_download(repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, filename=FILENAME)
-
-        FILENAME = f"{hook_point}/sae_weights.safetensors"
-        path = hf_hub_download(
-            repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, filename=FILENAME
-        )
-
-        FILENAME = f"{hook_point}/sparsity.safetensors"
-        path = hf_hub_download(
-            repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, filename=FILENAME
-        )
+        _, sae_path, _ = download_sae_from_hf(
+            repo_id=GPT2_SMALL_RESIDUAL_SAES_REPO_ID, folder_name=hook_point
+        )
 
         # Then use our function to download the files
-        folder_path = os.path.dirname(path)
-        sae = SparseAutoencoder.load_from_pretrained(folder_path)
+        folder_path = os.path.dirname(sae_path)
+        sae = SparseAutoencoder.load_from_pretrained(folder_path, device=device)
         sparsity = load_sparsity(folder_path)
+        sparsity = sparsity.to(device)
         saes[hook_point] = sae
         sparsities[hook_point] = sparsity

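For reference, here is how the two new helpers might be used together. This is a minimal sketch, not part of the commit: it assumes sae_lens is installed, that the functions live at the module path shown above, and the hook point chosen is only an example.

# Minimal usage sketch (assumptions noted above; not part of this commit).
import os

from sae_lens.toolkit.pretrained_saes import (
    download_sae_from_hf,
    load_sae_from_local_path,
)

# Fetch cfg.json, sae_weights.safetensors and, if present, sparsity.safetensors
# for one hook point; sparsity_path comes back None when the repo ships no
# sparsity file, which is why the download above is wrapped in try/except.
cfg_path, sae_path, sparsity_path = download_sae_from_hf(
    repo_id="jbloom/GPT2-Small-SAEs-Reformatted",
    folder_name="blocks.8.hook_resid_pre",  # example hook point
)

# All three files land in the same cache folder, so load from its directory.
sae, sparsity = load_sae_from_local_path(os.path.dirname(sae_path))

Note that load_sae_from_local_path also calls load_sparsity, so it assumes a sparsity file is present in the folder.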

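And the updated top-level loader, again as a sketch under the same assumptions: passing hook_point restricts the download loop to a single SAE, and device moves both the SAE and its sparsity tensor.

from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Download one GPT2-small residual-stream SAE instead of all thirteen.
saes, sparsities = get_gpt2_res_jb_saes(
    hook_point="blocks.8.hook_resid_pre",  # omit to fetch every hook point
    device="cpu",                          # e.g. "cuda" when a GPU is available
)
sae = saes["blocks.8.hook_resid_pre"]
sparsity = sparsities["blocks.8.hook_resid_pre"]

Both dictionaries are keyed by hook point name, so the same string indexes the SAE and its sparsity tensor.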