Skip to content

Commit

Permalink
consolidates with SAE class load_legacy function & adds test
Browse files Browse the repository at this point in the history
  • Loading branch information
evanhanders committed Apr 18, 2024
1 parent fda2b57 commit 0f85ded
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
36 changes: 16 additions & 20 deletions sae_lens/toolkit/pretrained_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from typing import Optional

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from sae_lens.training.utils import BackwardsCompatiblePickleClass


def load_sparsity(path: str) -> torch.Tensor:
Expand Down Expand Up @@ -138,36 +137,33 @@ def convert_connor_rob_sae_to_our_saelens_format(

def convert_old_to_modern_saelens_format(
    pytorch_file: str,
    out_folder: Optional[str] = None,
    force: bool = False,
):
    """
    Reads a pretrained SAE from the old pickle-style SAELens .pt format, then
    saves a modern-format SAELens SAE folder.

    Arguments:
    ----------
    pytorch_file: str
        Path of old format file to open.
    out_folder: str, optional
        Path where the new SAE will be stored; if None, defaults to a folder
        next to pytorch_file named after it with the '.pt' suffix removed.
    force: bool, optional
        If out_folder already exists, this function will not save unless force=True.

    Returns:
    --------
    SparseAutoencoder
        The autoencoder loaded from the legacy file (already saved to out_folder).

    Raises:
    -------
    FileExistsError
        If out_folder already exists and force is False.
    """
    file_path = pathlib.Path(pytorch_file)
    if out_folder is None:
        # Default: sibling folder named after the file, minus its '.pt' suffix.
        out_folder = file_path.parent / file_path.stem
    out_folder = pathlib.Path(out_folder)
    if (not force) and out_folder.exists():
        raise FileExistsError(f"{out_folder} already exists and force=False")
    out_folder.mkdir(exist_ok=True, parents=True)

    # Delegate legacy deserialization to the SAE class's own loader, then
    # re-save in the modern folder format.
    autoencoder = SparseAutoencoder.load_from_pretrained_legacy(str(file_path))
    autoencoder.save_model(out_folder)
    return autoencoder

def get_gpt2_small_ckrk_attn_out_saes() -> dict[str, SparseAutoencoder]:

Expand Down
6 changes: 4 additions & 2 deletions sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,12 @@ def load_from_pretrained_legacy(cls, path: str):
)
state_dict["cfg"].device = "mps"
else:
state_dict = torch.load(path)
state_dict = torch.load(
path,
pickle_module=BackwardsCompatiblePickleClass
)
except Exception as e:
raise IOError(f"Error loading the state dictionary from .pt file: {e}")

elif path.endswith(".pkl.gz"):
try:
with gzip.open(path, "rb") as f:
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/toolkit/test_pretrained_saes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pathlib
import shutil

import pytest
import torch

from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from sae_lens.toolkit import pretrained_saes
from sae_lens.training.config import LanguageModelSAERunnerConfig

def test_convert_old_to_modern_saelens_format():
    """Round-trip a legacy .pt SAE through the converter and check the weights match."""
    out_dir = pathlib.Path('unit_test_tmp')
    out_dir.mkdir(exist_ok=True)
    legacy_out_file = str(out_dir / 'test.pt')
    new_out_folder = str(out_dir / 'test')

    try:
        # Make an SAE and save it in the legacy (pickle-style .pt) format.
        cfg = LanguageModelSAERunnerConfig(
            dtype=torch.float32,
            hook_point='blocks.0.hook_mlp_out',
        )
        old_sae = SparseAutoencoder(cfg)
        old_sae.save_model_legacy(legacy_out_file)

        # Convert the legacy file into the modern folder format.
        pretrained_saes.convert_old_to_modern_saelens_format(
            legacy_out_file,
            new_out_folder,
        )

        # Load from the newly converted folder.
        new_sae = SparseAutoencoder.load_from_pretrained(new_out_folder)
    finally:
        # Clean up even if save/convert/load fails, so reruns don't hit
        # stale files (force=False would then raise FileExistsError).
        shutil.rmtree(out_dir)

    # Weights must survive the format round-trip unchanged.
    assert torch.allclose(new_sae.W_enc, old_sae.W_enc)
    assert torch.allclose(new_sae.W_dec, old_sae.W_dec)

0 comments on commit 0f85ded

Please sign in to comment.