Skip to content

Commit

Permalink
Merge pull request #71 from weissercn/main
Browse files Browse the repository at this point in the history
Addressing notebook issues
  • Loading branch information
jbloomAus authored Apr 8, 2024
2 parents 4d7d1e7 + 1db0b5a commit 8417505
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 3 deletions.
7 changes: 6 additions & 1 deletion sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.utils import BackwardsCompatiblePickleClass


class ForwardOutput(NamedTuple):
Expand Down Expand Up @@ -241,7 +242,11 @@ def load_from_pretrained(cls, path: str):
if path.endswith(".pt"):
try:
if torch.backends.mps.is_available():
state_dict = torch.load(path, map_location="mps")
state_dict = torch.load(
path,
map_location="mps",
pickle_module=BackwardsCompatiblePickleClass,
)
state_dict["cfg"].device = "mps"
else:
state_dict = torch.load(path)
Expand Down
4 changes: 4 additions & 0 deletions sae_lens/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def find_class(self, module: str, name: str):
return super().find_class(module, name)


class BackwardsCompatiblePickleClass:
Unpickler = BackwardsCompatibleUnpickler


def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int]):
"""
Shuffles two buffers on disk.
Expand Down
6 changes: 4 additions & 2 deletions tutorials/logits_lens_with_features.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@
"metadata": {},
"outputs": [],
"source": [
"import nltk\n",
"nltk.download('averaged_perceptron_tagger')\n",
"# get the vocab we need to filter to formulate token sets.\n",
"vocab = model.tokenizer.get_vocab() # type: ignore\n",
"\n",
Expand Down Expand Up @@ -608,7 +610,7 @@
"metadata": {},
"outputs": [],
"source": [
"for category in [\"starts_with_space\"]:\n",
"for category in [\"boys_names\"]:\n",
" plot_top_k_feature_projections_by_token_and_category(\n",
" token_set_selected,\n",
" df_enrichment_scores,\n",
Expand Down Expand Up @@ -655,7 +657,7 @@
"\n",
"fig = px.area(\n",
" tmp_df,\n",
" title=\"Kurtosis by Layer\",\n",
" title=\"Skewness by Layer\",\n",
" width=800,\n",
" height=600,\n",
" color_discrete_sequence=px.colors.sequential.Turbo,\n",
Expand Down

0 comments on commit 8417505

Please sign in to comment.