Skip to content

Commit

Permalink
Safe to_str_tokens, fix memory issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Johnny Lin authored and Johnny Lin committed Mar 30, 2024
1 parent 85d8f57 commit 901b888
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 350 deletions.
70 changes: 58 additions & 12 deletions sae_analysis/neuronpedia_runner.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import os
from typing import Any, Optional, cast
from typing import Any, Dict, List, Optional, Union, cast

# set TOKENIZERS_PARALLELISM to false to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import time

import torch
from sae_vis.data_fetching_fns import get_feature_data
from sae_vis.data_storing_fns import FeatureVisParams, to_str_tokens
from sae_vis.data_storing_fns import FeatureVisParams
from tqdm import tqdm

import numpy as np
Expand All @@ -16,6 +16,8 @@

from matplotlib import colors

OUT_OF_RANGE_TOKEN = "<|outofrange|>"

BG_COLOR_MAP = colors.LinearSegmentedColormap.from_list(
"bg_color_map", ["white", "darkorange"]
)
Expand Down Expand Up @@ -47,8 +49,9 @@ def __init__(
n_features_at_a_time: int = 1024,
buffer_tokens_left: int = 8,
buffer_tokens_right: int = 8,
# start_batch
start_batch: int = 0,
# start and end batch
start_batch_inclusive: int = 0,
end_batch_inclusive: Optional[int] = None,
):
self.sae_path = sae_path
if init_session:
Expand All @@ -60,7 +63,8 @@ def __init__(
self.buffer_tokens_right = buffer_tokens_right
self.n_batches_to_sample_from = n_batches_to_sample_from
self.n_prompts_to_select = n_prompts_to_select
self.start_batch = start_batch
self.start_batch = start_batch_inclusive
self.end_batch = end_batch_inclusive

# Deal with file structure
if not os.path.exists(neuronpedia_parent_folder):
Expand Down Expand Up @@ -107,6 +111,32 @@ def get_tokens(
def round_list(self, to_round: list[float]) -> list[float]:
    """
    Round every value in ``to_round`` to 3 decimal places.

    Args:
        to_round: Numeric values to round (anything ``np.round`` accepts).

    Returns:
        The rounded values as native Python floats.
    """
    # ndarray.tolist() converts NumPy scalars to built-in Python floats.
    # list(ndarray) would keep NumPy scalar types (e.g. np.float32), which
    # json.dumps cannot serialize.
    return np.round(to_round, 3).tolist()

def to_str_tokens_safe(
    self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor]
):
    """
    Convert token ids to their string forms without failing on unknown ids.

    Any id that is negative, larger than ``self.model.cfg.d_vocab - 1``, or
    absent from ``vocab_dict`` is rendered as ``OUT_OF_RANGE_TOKEN`` instead
    of raising ``KeyError``.

    Args:
        vocab_dict: Mapping from token id to its string form.
        tokens: A single token id, a (possibly nested) list of ids, or a
            tensor of ids.

    Returns:
        A single string when ``tokens`` is an int; otherwise a (possibly
        nested) list of strings with the same shape as ``tokens``.
    """
    vocab_max_index = self.model.cfg.d_vocab - 1

    def lookup(token_id: int) -> str:
        # Ids outside [0, vocab_max_index] have no valid vocab entry; the
        # .get fallback also covers in-range ids missing from vocab_dict.
        if 0 <= token_id <= vocab_max_index:
            return vocab_dict.get(token_id, OUT_OF_RANGE_TOKEN)
        return OUT_OF_RANGE_TOKEN

    # Deal with the int case separately: it has no shape to preserve.
    if isinstance(tokens, int):
        return lookup(tokens)

    # If the tokens are a (possibly nested) list, turn them into a tensor so
    # that flatten() and shape are available below.
    if isinstance(tokens, list):
        tokens = torch.tensor(tokens)

    # Translate the flattened ids, then restore the original shape.
    str_tokens = [lookup(t) for t in tokens.flatten().tolist()]
    return np.reshape(str_tokens, tokens.shape).tolist()

def run(self):
"""
Generate the Neuronpedia outputs.
Expand Down Expand Up @@ -137,6 +167,16 @@ def run(self):
feature_idx = np.array_split(feature_idx, n_subarrays)
feature_idx = [x.tolist() for x in feature_idx]

print(f"==== Starting at batch: {self.start_batch}")
if self.end_batch is not None:
print(f"==== Ending at batch: {self.end_batch}")

if self.start_batch > len(feature_idx) + 1:
print(
f"Start batch {self.start_batch} is greater than number of batches + 1 {len(feature_idx)}, exiting"
)
exit()

# write dead into file so we can create them as dead in Neuronpedia
skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes)
skipped_indexes_json = json.dumps({"skipped_indexes": list(skipped_indexes)})
Expand Down Expand Up @@ -166,16 +206,20 @@ def run(self):
}
# pad with blank tokens to the actual vocab size
for i in range(len(vocab_dict), self.model.cfg.d_vocab):
vocab_dict[i] = " "
vocab_dict[i] = OUT_OF_RANGE_TOKEN

with torch.no_grad():
feature_batch_count = 0
for features_to_process in tqdm(feature_idx):
feature_batch_count = feature_batch_count + 1

if feature_batch_count < self.start_batch:
print(f"Skipping batch: {feature_batch_count}")
# print(f"Skipping batch - it's before start_batch: {feature_batch_count}")
continue
if self.end_batch is not None and feature_batch_count > self.end_batch:
# print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
continue

print(f"Doing batch: {feature_batch_count}")

feature_vis_params = FeatureVisParams(
Expand Down Expand Up @@ -255,11 +299,11 @@ def run(self):
# feature.left_tables_data.correlated_features_pearson
# )

feature_output["neg_str"] = to_str_tokens(
feature_output["neg_str"] = self.to_str_tokens_safe(
vocab_dict, feature.middle_plots_data.bottom10_token_ids
)
feature_output["neg_values"] = bottom10_logits
feature_output["pos_str"] = to_str_tokens(
feature_output["pos_str"] = self.to_str_tokens_safe(
vocab_dict, feature.middle_plots_data.top10_token_ids
)
feature_output["pos_values"] = top10_logits
Expand Down Expand Up @@ -320,11 +364,13 @@ def run(self):
negContribs = []
for i in range(len(sd.token_ids)):
strs.append(
to_str_tokens(vocab_dict, sd.token_ids[i])
self.to_str_tokens_safe(
vocab_dict, sd.token_ids[i]
)
)
posContrib = {}
posTokens = [
to_str_tokens(vocab_dict, j)
self.to_str_tokens_safe(vocab_dict, j)
for j in sd.top5_token_ids[i]
]
if len(posTokens) > 0:
Expand All @@ -335,7 +381,7 @@ def run(self):
posContribs.append(posContrib)
negContrib = {}
negTokens = [
to_str_tokens(vocab_dict, j)
self.to_str_tokens_safe(vocab_dict, j)
for j in sd.bottom5_token_ids[i]
]
if len(negTokens) > 0:
Expand Down
Loading

0 comments on commit 901b888

Please sign in to comment.