Skip to content

Commit

Permalink
Merge pull request #100 from jbloomAus/np_improvements
Browse files Browse the repository at this point in the history
Improvements to Neuronpedia Runner
  • Loading branch information
jbloomAus authored Apr 23, 2024
2 parents a5e6850 + 4b5412b commit 5118f7f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 20 deletions.
12 changes: 11 additions & 1 deletion sae_lens/analysis/neuronpedia_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def run(self):
"model_id": self.model_id,
"layer": str(self.layer),
"sae_id": self.sae_id,
"log_sparsity": self.sparsity_threshold,
"skipped_indexes": list(skipped_indexes),
}
)
Expand Down Expand Up @@ -237,6 +238,13 @@ def run(self):
# print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
continue

output_file = f"{self.outputs_dir}/batch-{feature_batch_count}.json"
# if output_file exists, skip
if os.path.isfile(output_file):
logline = f"\n++++++++++ Skipping Batch #{feature_batch_count} output. File exists: {output_file} ++++++++++\n"
print(logline)
continue

print(f"========== Running Batch #{feature_batch_count} ==========")

layout = SaeVisLayoutConfig(
Expand Down Expand Up @@ -421,9 +429,11 @@ def run(self):
json_object = json.dumps(to_write, cls=NpEncoder)

with open(
f"{self.outputs_dir}/batch-{feature_batch_count}.json",
output_file,
"w",
) as f:
f.write(json_object)

logline = f"\n========== Completed Batch #{feature_batch_count} output: {output_file} ==========\n"

return
87 changes: 68 additions & 19 deletions tutorials/neuronpedia/neuronpedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sae_lens.training.sparse_autoencoder import SparseAutoencoder

OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs")
RUN_SETTINGS_FILE = "run_settings.json"

app = typer.Typer(
add_completion=False,
Expand Down Expand Up @@ -77,16 +78,6 @@ def generate(
Enter value""",
),
] = 128,
resume_from_batch: Annotated[
int,
typer.Option(
min=1,
help="Batch number to resume from.",
prompt="""
Do you want to resume from a specific batch number?
Enter 1 to start from the beginning""",
),
] = 1,
n_batches_to_sample: Annotated[
int,
typer.Option(
Expand All @@ -109,6 +100,16 @@ def generate(
),
] = 4096
* 6,
resume_from_batch: Annotated[
int,
typer.Option(
min=1,
help="Batch number to resume from.",
prompt="""
Do you want to resume from a specific batch number?
Enter 1 to start from the beginning. Existing batch files will not be overwritten.""",
),
] = 1,
):
"""
This will start a batch job that generates features for Neuronpedia for a specific SAE. To upload those features, use the 'upload' command afterwards.
Expand Down Expand Up @@ -141,19 +142,67 @@ def generate(
)
model_id = sparse_autoencoder.cfg.model_name

# make the outputs subdirectory if it doesn't exist, ensure it's not a file
outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}"
outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir)
if outputs_dir.exists() and outputs_dir.is_file():
print(f"Error: Output directory {outputs_dir.as_posix()} exists and is a file.")
raise typer.Abort()
outputs_dir.mkdir(parents=True, exist_ok=True)
    # Check if output_dir has any existing batch-*.json output files
batch_files = list(outputs_dir.glob("batch-*.json"))
if len(batch_files) > 0 and resume_from_batch == 1:
print(
f"Error: Output directory {outputs_dir.as_posix()} has existing batch files. This is only allowed if you are resuming from a batch. Please delete or move the existing batch-*.json files."
)
raise typer.Abort()

# Check if output_dir has a run_settings.json file. If so, load those settings.
run_settings_path = outputs_dir.joinpath(RUN_SETTINGS_FILE)
print("\n")
if run_settings_path.exists() and run_settings_path.is_file():
# load the json file
with open(run_settings_path, "r") as f:
run_settings = json.load(f)
print(
f"[yellow]Found existing run_settings.json in {run_settings_path.as_posix()}, checking them."
)
if run_settings["log_sparsity"] != log_sparsity:
print(
f"[red]Error: log_sparsity in {run_settings_path.as_posix()} doesn't match the current log_sparsity:\n{run_settings['log_sparsity']} vs {log_sparsity}"
)
raise typer.Abort()
if run_settings["sae_id"] != sae_id:
print(
f"[red]Error: sae_id in {run_settings_path.as_posix()} doesn't match the current sae_id:\n{run_settings['sae_id']} vs {sae_id}"
)
raise typer.Abort()
if run_settings["sae_path"] != sae_path_string:
print(
f"[red]Error: sae_path in {run_settings_path.as_posix()} doesn't match the current sae_path:\n{run_settings['sae_path']} vs {sae_path_string}"
)
raise typer.Abort()
if run_settings["n_batches_to_sample"] != n_batches_to_sample:
print(
f"[red]Error: n_batches_to_sample in {run_settings_path.as_posix()} doesn't match the current n_batches_to_sample:\n{run_settings['n_batches_to_sample']} vs {n_batches_to_sample}"
)
raise typer.Abort()
if run_settings["n_prompts_to_select"] != n_prompts_to_select:
print(
f"[red]Error: n_prompts_to_select in {run_settings_path.as_posix()} doesn't match the current n_prompts_to_select:\n{run_settings['n_prompts_to_select']} vs {n_prompts_to_select}"
)
raise typer.Abort()
if run_settings["feat_per_batch"] != feat_per_batch:
print(
f"[red]Error: feat_per_batch in {run_settings_path.as_posix()} doesn't match the current feat_per_batch:\n{run_settings['feat_per_batch']} vs {feat_per_batch}"
)
raise typer.Abort()
print("[green]All settings match, using existing run_settings.json.")
else:
print(f"[green]Creating run_settings.json in {run_settings_path.as_posix()}.")
run_settings = {
"sae_id": sae_id,
"sae_path": sae_path_string,
"log_sparsity": log_sparsity,
"n_batches_to_sample": n_batches_to_sample,
"n_prompts_to_select": n_prompts_to_select,
"feat_per_batch": feat_per_batch,
}
with open(run_settings_path, "w") as f:
json.dump(run_settings, f, indent=4)

sparsity = load_sparsity(sae_path_string)
# convert sparsity to logged sparsity if it's not
Expand Down Expand Up @@ -290,7 +339,7 @@ def upload(
dir_okay=True,
readable=True,
resolve_path=True,
prompt="What is the absolute, full local file path to the feature outputs directory?",
prompt="What is the absolute local file path to the feature outputs directory?",
),
],
host: Annotated[
Expand Down Expand Up @@ -347,7 +396,7 @@ def upload_dead_stubs(
dir_okay=True,
readable=True,
resolve_path=True,
prompt="What is the absolute, full local file path to the feature outputs directory?",
prompt="What is the absolute local file path to the feature outputs directory?",
),
],
host: Annotated[
Expand Down

0 comments on commit 5118f7f

Please sign in to comment.