From 18fd4a1d43bbc8b7a490779f06e4f3c6983e3ea7 Mon Sep 17 00:00:00 2001 From: Lucy Farnik Date: Sun, 25 Feb 2024 20:16:39 -0800 Subject: [PATCH] Fixed edge case in activation cache shuffling --- sae_training/cache_activations_runner.py | 11 ++++++----- sae_training/utils.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/sae_training/cache_activations_runner.py b/sae_training/cache_activations_runner.py index cb4daae6..c1d2109f 100644 --- a/sae_training/cache_activations_runner.py +++ b/sae_training/cache_activations_runner.py @@ -52,8 +52,9 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): ) # More final shuffling (mostly in case we didn't end on an i divisible by shuffle_every_n_buffers) - for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"): - shuffle_activations_pairwise( - activations_store.cfg.cached_activations_path, - buffer_idx_range=(0, n_buffers), - ) + if n_buffers > 1: + for _ in tqdm(range(cfg.n_shuffles_final), desc="Final shuffling"): + shuffle_activations_pairwise( + activations_store.cfg.cached_activations_path, + buffer_idx_range=(0, n_buffers), + ) diff --git a/sae_training/utils.py b/sae_training/utils.py index c2633a56..c79f5c99 100644 --- a/sae_training/utils.py +++ b/sae_training/utils.py @@ -93,8 +93,8 @@ def shuffle_activations_pairwise(datapath: str, buffer_idx_range: Tuple[int, int Shuffles two buffers on disk. """ assert ( - buffer_idx_range[0] < buffer_idx_range[1] - ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1]" + buffer_idx_range[0] < buffer_idx_range[1] - 1 + ), "buffer_idx_range[0] must be smaller than buffer_idx_range[1] by at least 1" buffer_idx1 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item() buffer_idx2 = torch.randint(buffer_idx_range[0], buffer_idx_range[1], (1,)).item()