-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Saving and loading activations from disk
- Loading branch information
1 parent
5f73918
commit 309e2de
Showing
5 changed files
with
222 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import math | ||
import os | ||
|
||
import torch | ||
from transformer_lens import HookedTransformer | ||
from tqdm import tqdm | ||
|
||
from sae_training.activations_store import ActivationsStore | ||
from sae_training.config import CacheActivationsRunnerConfig | ||
from sae_training.utils import shuffle_activations_pairwise | ||
|
||
|
||
def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
    """Generate activation buffers, save each to disk, and shuffle them on disk.

    Buffers are written as ``<cached_activations_path>/<i>.pt`` for
    ``i`` in ``range(n_buffers)``, where ``n_buffers`` is
    ``ceil(total_training_tokens / tokens_per_buffer)`` and
    ``tokens_per_buffer`` is
    ``store_batch_size * context_size * n_batches_in_buffer``.

    Args:
        cfg: Run configuration. Fields read directly here: ``model_name``,
            ``device``, ``store_batch_size``, ``context_size``,
            ``n_batches_in_buffer``, ``total_training_tokens``,
            ``shuffle_every_n_buffers``, ``n_shuffles_with_last_section``,
            ``n_shuffles_in_entire_dir`` and ``n_shuffles_final``.

    Raises:
        Exception: If the activations directory already exists and is not
            empty (nothing is written or deleted in that case).
    """
    model = HookedTransformer.from_pretrained(cfg.model_name)
    model.to(cfg.device)
    activations_store = ActivationsStore(cfg, model, create_dataloader=False)

    # The output directory is read from the store's own config.
    # NOTE(review): presumably the same object as `cfg` -- confirm.
    cache_dir = activations_store.cfg.cached_activations_path

    # Refuse to write into a directory that already has files in it;
    # create the directory when it does not exist yet.
    if os.path.exists(cache_dir):
        if len(os.listdir(cache_dir)) > 0:
            raise Exception(f"Activations directory ({cache_dir}) is not empty. Please delete it or specify a different path. Exiting the script to prevent accidental deletion of files.")
    else:
        os.makedirs(cache_dir)

    print(f"Started caching {cfg.total_training_tokens} activations")
    tokens_per_buffer = cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer
    # Round up so that at least total_training_tokens tokens are covered.
    n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer)

    for i in tqdm(range(n_buffers), desc="Caching activations"):
        buffer = activations_store.get_buffer(cfg.n_batches_in_buffer)
        torch.save(buffer, f"{cache_dir}/{i}.pt")
        # Drop the reference before the next buffer is produced.
        del buffer

        # Every shuffle_every_n_buffers buffers (but not on the first one),
        # shuffle what has been written to disk so far.
        if i % cfg.shuffle_every_n_buffers == 0 and i > 0:
            # Random pairwise shuffling between the last
            # shuffle_every_n_buffers buffers.
            for _ in range(cfg.n_shuffles_with_last_section):
                shuffle_activations_pairwise(
                    cache_dir,
                    buffer_idx_range=(i - cfg.shuffle_every_n_buffers, i),
                )

            # More random pairwise shuffling between all the buffers so far.
            for _ in range(cfg.n_shuffles_in_entire_dir):
                shuffle_activations_pairwise(
                    cache_dir,
                    buffer_idx_range=(0, i),
                )

    # More final shuffling (mostly in case we didn't end on an i divisible
    # by shuffle_every_n_buffers).
    for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"):
        shuffle_activations_pairwise(
            cache_dir,
            buffer_idx_range=(0, n_buffers),
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
""" | ||
This is a utility for timing the execution of a function. | ||
(It has to live in a separate file: putting it in utils.py causes circular imports. TODO: find a permanent home for it.) | ||
""" | ||
|
||
from functools import wraps | ||
import time | ||
|
||
def timeit(func):
    """Decorator that prints how long each call to ``func`` takes.

    The wrapped function is called with the arguments unchanged and its
    return value is passed through. After a call returns, one line naming
    the function, its positional and keyword arguments, and the elapsed
    wall-clock time (``time.perf_counter``) is printed. If ``func`` raises,
    the exception propagates and nothing is printed.

    Taken from https://dev.to/kcdchennai/python-decorator-to-measure-execution-time-54hk

    Args:
        func: The callable to time.

    Returns:
        A wrapper around ``func`` that carries its metadata (via
        ``functools.wraps``).
    """
    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        total_time = time.perf_counter() - start_time
        # Note: args/kwargs are formatted with their repr, so large
        # arguments produce correspondingly large log lines.
        print(f'Function {func.__name__}{args} {kwargs} Took {total_time:.4f} seconds')
        return result
    return timeit_wrapper
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters