diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index a34e970a..95b09cb7 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -191,6 +191,7 @@ def run(self): "model_id": self.model_id, "layer": str(self.layer), "sae_id": self.sae_id, + "log_sparsity": self.sparsity_threshold, "skipped_indexes": list(skipped_indexes), } ) @@ -237,6 +238,13 @@ def run(self): # print(f"Skipping batch - it's after end_batch: {feature_batch_count}") continue + output_file = f"{self.outputs_dir}/batch-{feature_batch_count}.json" + # if output_file exists, skip + if os.path.isfile(output_file): + logline = f"\n++++++++++ Skipping Batch #{feature_batch_count} output. File exists: {output_file} ++++++++++\n" + print(logline) + continue + print(f"========== Running Batch #{feature_batch_count} ==========") layout = SaeVisLayoutConfig( @@ -421,9 +429,12 @@ def run(self): json_object = json.dumps(to_write, cls=NpEncoder) with open( - f"{self.outputs_dir}/batch-{feature_batch_count}.json", + output_file, "w", ) as f: f.write(json_object) + logline = f"\n========== Completed Batch #{feature_batch_count} output: {output_file} ==========\n" + print(logline) + return diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py index aa36a6b6..24d99865 100755 --- a/tutorials/neuronpedia/neuronpedia.py +++ b/tutorials/neuronpedia/neuronpedia.py @@ -21,6 +21,7 @@ from sae_lens.training.sparse_autoencoder import SparseAutoencoder OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs") +RUN_SETTINGS_FILE = "run_settings.json" app = typer.Typer( add_completion=False, @@ -77,16 +78,6 @@ def generate( Enter value""", ), ] = 128, - resume_from_batch: Annotated[ - int, - typer.Option( - min=1, - help="Batch number to resume from.", - prompt=""" -Do you want to resume from a specific batch number? 
-Enter 1 to start from the beginning""", - ), - ] = 1, n_batches_to_sample: Annotated[ int, typer.Option( @@ -109,6 +100,16 @@ def generate( ), ] = 4096 * 6, + resume_from_batch: Annotated[ + int, + typer.Option( + min=1, + help="Batch number to resume from.", + prompt=""" +Do you want to resume from a specific batch number? +Enter 1 to start from the beginning. Existing batch files will not be overwritten.""", + ), + ] = 1, ): """ This will start a batch job that generates features for Neuronpedia for a specific SAE. To upload those features, use the 'upload' command afterwards. @@ -141,19 +142,67 @@ def generate( ) model_id = sparse_autoencoder.cfg.model_name + # make the outputs subdirectory if it doesn't exist, ensure it's not a file outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}" outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir) if outputs_dir.exists() and outputs_dir.is_file(): print(f"Error: Output directory {outputs_dir.as_posix()} exists and is a file.") raise typer.Abort() outputs_dir.mkdir(parents=True, exist_ok=True) - # Check if output_dir has any files starting with "batch_" - batch_files = list(outputs_dir.glob("batch-*.json")) - if len(batch_files) > 0 and resume_from_batch == 1: - print( - f"Error: Output directory {outputs_dir.as_posix()} has existing batch files. This is only allowed if you are resuming from a batch. Please delete or move the existing batch-*.json files." - ) - raise typer.Abort() + + # Check if output_dir has a run_settings.json file. If so, load those settings. + run_settings_path = outputs_dir.joinpath(RUN_SETTINGS_FILE) + print("\n") + if run_settings_path.exists() and run_settings_path.is_file(): + # load the json file + with open(run_settings_path, "r") as f: + run_settings = json.load(f) + print( + f"[yellow]Found existing run_settings.json in {run_settings_path.as_posix()}, checking them." 
+ ) + if run_settings["log_sparsity"] != log_sparsity: + print( + f"[red]Error: log_sparsity in {run_settings_path.as_posix()} doesn't match the current log_sparsity:\n{run_settings['log_sparsity']} vs {log_sparsity}" + ) + raise typer.Abort() + if run_settings["sae_id"] != sae_id: + print( + f"[red]Error: sae_id in {run_settings_path.as_posix()} doesn't match the current sae_id:\n{run_settings['sae_id']} vs {sae_id}" + ) + raise typer.Abort() + if run_settings["sae_path"] != sae_path_string: + print( + f"[red]Error: sae_path in {run_settings_path.as_posix()} doesn't match the current sae_path:\n{run_settings['sae_path']} vs {sae_path_string}" + ) + raise typer.Abort() + if run_settings["n_batches_to_sample"] != n_batches_to_sample: + print( + f"[red]Error: n_batches_to_sample in {run_settings_path.as_posix()} doesn't match the current n_batches_to_sample:\n{run_settings['n_batches_to_sample']} vs {n_batches_to_sample}" + ) + raise typer.Abort() + if run_settings["n_prompts_to_select"] != n_prompts_to_select: + print( + f"[red]Error: n_prompts_to_select in {run_settings_path.as_posix()} doesn't match the current n_prompts_to_select:\n{run_settings['n_prompts_to_select']} vs {n_prompts_to_select}" + ) + raise typer.Abort() + if run_settings["feat_per_batch"] != feat_per_batch: + print( + f"[red]Error: feat_per_batch in {run_settings_path.as_posix()} doesn't match the current feat_per_batch:\n{run_settings['feat_per_batch']} vs {feat_per_batch}" + ) + raise typer.Abort() + print("[green]All settings match, using existing run_settings.json.") + else: + print(f"[green]Creating run_settings.json in {run_settings_path.as_posix()}.") + run_settings = { + "sae_id": sae_id, + "sae_path": sae_path_string, + "log_sparsity": log_sparsity, + "n_batches_to_sample": n_batches_to_sample, + "n_prompts_to_select": n_prompts_to_select, + "feat_per_batch": feat_per_batch, + } + with open(run_settings_path, "w") as f: + json.dump(run_settings, f, indent=4) sparsity = 
load_sparsity(sae_path_string) # convert sparsity to logged sparsity if it's not @@ -290,7 +339,7 @@ def upload( dir_okay=True, readable=True, resolve_path=True, - prompt="What is the absolute, full local file path to the feature outputs directory?", + prompt="What is the absolute local file path to the feature outputs directory?", ), ], host: Annotated[ @@ -347,7 +396,7 @@ def upload_dead_stubs( dir_okay=True, readable=True, resolve_path=True, - prompt="What is the absolute, full local file path to the feature outputs directory?", + prompt="What is the absolute local file path to the feature outputs directory?", ), ], host: Annotated[