Skip to content

Commit

Permalink
Fixed edge case in activation cache shuffling
Browse files: browse the repository at this point in the history
  • Loading branch information
lucyfarnik committed Feb 26, 2024
1 parent 37771ce commit 18fd4a1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions sae_training/cache_activations_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig):
)

# More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers)
for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"):
shuffle_activations_pairwise(
activations_store.cfg.cached_activations_path,
buffer_idx_range=(0, n_buffers),
)
if n_buffers > 1:
for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"):
shuffle_activations_pairwise(
activations_store.cfg.cached_activations_path,
buffer_idx_range=(0, n_buffers),
)
4 changes: 2 additions & 2 deletions sae_training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int
Shuffles two buffers on disk.
"""
assert (
buffer_idx_range[0] < buffer_idx_range[1]
), "buffer_idx_range[0] must be smaller than buffer_idx_range[1]"
buffer_idx_range[0] < buffer_idx_range[1] - 1
), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1"

buffer_idx1 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()
Expand Down

0 comments on commit 18fd4a1

Please sign in to comment.