fix: load the same config from_pretrained and get_sae_config #361

Open · wants to merge 1 commit into main
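This PR makes SAE.from_pretrained and get_sae_config produce the same configuration dict by moving the cfg_overrides merge, device assignment, and handle_config_defaulting call into get_sae_config itself. A minimal sketch of the behaviour the new test asserts (release, sae_id, and the indexed return value are taken from the test added below; the exact return shape of from_pretrained is an assumption):

    from sae_lens.sae import SAE
    from sae_lens.toolkit.pretrained_sae_loaders import SAEConfigLoadOptions, get_sae_config

    # Config dict returned alongside the SAE by from_pretrained (second tuple element).
    from_pretrained_cfg = SAE.from_pretrained(
        "gpt2-small-res-jb", sae_id="blocks.0.hook_resid_pre", device="cpu"
    )[1]

    # Config dict loaded directly; after this change both paths apply the same
    # overrides and defaulting, so the two dicts should be identical.
    direct_cfg = get_sae_config(
        "gpt2-small-res-jb",
        sae_id="blocks.0.hook_resid_pre",
        options=SAEConfigLoadOptions(device="cpu"),
    )
    assert direct_cfg == from_pretrained_cfg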
26 changes: 18 additions & 8 deletions sae_lens/toolkit/pretrained_sae_loaders.py
@@ -51,14 +51,9 @@
    options = SAEConfigLoadOptions(
        device=device,
        force_download=force_download,
+        cfg_overrides=cfg_overrides,
    )
    cfg_dict = get_sae_config(release, sae_id=sae_id, options=options)
-    # Apply overrides if provided
-    if cfg_overrides is not None:
-        cfg_dict.update(cfg_overrides)
-    cfg_dict["device"] = device
-    cfg_dict = handle_config_defaulting(cfg_dict)
-
    repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id=sae_id)

    weights_filename = f"{folder_name}/sae_weights.safetensors"
@@ -116,6 +111,9 @@
    with open(cfg_path, "r") as f:
        cfg_dict = json.load(f)

+    if options.device is not None:
+        cfg_dict["device"] = options.device
+
    return cfg_dict


@@ -310,7 +308,7 @@
    else:
        raise ValueError("Hook name not found in folder_name.")

-    return {
+    cfg = {
        "architecture": "jumprelu",
        "d_in": d_in,
        "d_sae": d_sae,
@@ -329,6 +327,10 @@
        "apply_b_dec_to_input": False,
        "normalize_activations": None,
    }
+    if options.device is not None:
+        cfg["device"] = options.device
[Codecov / codecov/patch: added line sae_lens/toolkit/pretrained_sae_loaders.py#L331 was not covered by tests]
+
+    return cfg


def gemma_2_sae_loader(
@@ -470,9 +472,17 @@
    saes_directory = get_pretrained_saes_directory()
    sae_info = saes_directory.get(release, None)
    repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id=sae_id)
+    cfg_overrides = options.cfg_overrides or {}
+    if sae_info is not None:
+        cfg_overrides = {**(sae_info.config_overrides or {}), **cfg_overrides}
+
    conversion_loader_name = get_conversion_loader_name(sae_info)
    config_getter = NAMED_PRETRAINED_SAE_CONFIG_GETTERS[conversion_loader_name]
-    return config_getter(repo_id, folder_name=folder_name, options=options)
+    cfg = {
+        **config_getter(repo_id, folder_name=folder_name, options=options),
+        **cfg_overrides,
+    }
+    return handle_config_defaulting(cfg)


def dictionary_learning_sae_loader_1(
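With this change, user-supplied overrides travel on SAEConfigLoadOptions and are merged inside get_sae_config: release-level config_overrides first, then the caller's cfg_overrides, then handle_config_defaulting fills in anything still missing. A minimal sketch under those assumptions (the release and override key below are illustrative only):

    from sae_lens.toolkit.pretrained_sae_loaders import SAEConfigLoadOptions, get_sae_config

    # Illustrative override: any key in the loaded config can be replaced by the caller.
    options = SAEConfigLoadOptions(device="cpu", cfg_overrides={"context_size": 256})
    cfg = get_sae_config("gpt2-small-res-jb", sae_id="blocks.0.hook_resid_pre", options=options)

    assert cfg["context_size"] == 256  # the caller's override wins over the stored value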
32 changes: 32 additions & 0 deletions tests/unit/toolkit/test_pretrained_sae_loaders.py
@@ -1,3 +1,4 @@
+from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_sae_loaders import SAEConfigLoadOptions, get_sae_config


@@ -9,11 +10,15 @@ def test_get_sae_config_sae_lens():
    )

    expected_cfg_dict = {
+        "activation_fn_str": "relu",
+        "apply_b_dec_to_input": True,
+        "architecture": "standard",
        "model_name": "gpt2-small",
        "hook_point": "blocks.0.hook_resid_pre",
        "hook_point_layer": 0,
        "hook_point_head_index": None,
        "dataset_path": "Skylion007/openwebtext",
+        "dataset_trust_remote_code": True,
        "is_dataset_tokenized": False,
        "context_size": 128,
        "use_cached_activations": False,
@@ -32,9 +37,13 @@ def test_get_sae_config_sae_lens():
        "lr": 0.0004,
        "lr_scheduler_name": None,
        "lr_warm_up_steps": 5000,
+        "model_from_pretrained_kwargs": {
+            "center_writing_weights": True,
+        },
        "train_batch_size": 4096,
        "use_ghost_grads": False,
        "feature_sampling_window": 1000,
+        "finetuning_scaling_factor": False,
        "feature_sampling_method": None,
        "resample_batches": 1028,
        "feature_reinit_scale": 0.2,
@@ -50,6 +59,10 @@ def test_get_sae_config_sae_lens():
        "d_sae": 24576,
        "tokens_per_buffer": 67108864,
        "run_name": "24576-L1-8e-05-LR-0.0004-Tokens-3.000e+08",
+        "neuronpedia": None,
+        "normalize_activations": "none",
+        "prepend_bos": True,
+        "sae_lens_training_version": None,
    }

    assert cfg_dict == expected_cfg_dict
@@ -81,6 +94,7 @@ def test_get_sae_config_connor_rob_hook_z():
        "context_size": 128,
        "normalize_activations": "none",
        "dataset_trust_remote_code": True,
+        "neuronpedia": None,
    }

    assert cfg_dict == expected_cfg_dict
@@ -111,6 +125,8 @@ def test_get_sae_config_gemma_2():
        "dataset_trust_remote_code": True,
        "apply_b_dec_to_input": False,
        "normalize_activations": None,
+        "device": "cpu",
+        "neuronpedia": None,
    }

    assert cfg_dict == expected_cfg_dict
@@ -144,6 +160,22 @@ def test_get_sae_config_dictionary_learning_1():
        "context_size": 128,
        "normalize_activations": "none",
        "neuronpedia_id": None,
+        "neuronpedia": None,
    }

    assert cfg_dict == expected_cfg_dict
+
+
+def test_get_sae_config_matches_from_pretrained():
+    from_pretrained_cfg_dict = SAE.from_pretrained(
+        "gpt2-small-res-jb",
+        sae_id="blocks.0.hook_resid_pre",
+        device="cpu",
+    )[1]
+    direct_sae_cfg = get_sae_config(
+        "gpt2-small-res-jb",
+        sae_id="blocks.0.hook_resid_pre",
+        options=SAEConfigLoadOptions(device="cpu"),
+    )
+
+    assert direct_sae_cfg == from_pretrained_cfg_dict