fixing config typing
switch to using explicit params for ActivationsStore config instead of RunnerConfig base class
chanind committed Apr 5, 2024
1 parent 773bc02 commit 9be3445
Showing 13 changed files with 370 additions and 186 deletions.
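
As a rough usage sketch (not part of the diff below), this is how call sites change; the names model and cfg are illustrative, assuming an already-loaded HookedTransformer and a LanguageModelSAERunnerConfig:

    # before this commit: the store pulled every setting off a config object
    store = ActivationsStore(cfg, model)

    # after this commit: the new classmethod unpacks the config into explicit
    # constructor arguments, so __init__ no longer needs a cfg at all
    store = ActivationsStore.from_config(model, cfg)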
181 changes: 128 additions & 53 deletions sae_training/activations_store.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import os
from typing import Any, Iterator, cast
from typing import Any, Iterator, Literal, TypeVar, cast

import torch
from datasets import (
@@ -12,6 +14,11 @@
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

from sae_training.config import (
CacheActivationsRunnerConfig,
LanguageModelSAERunnerConfig,
)

HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


@@ -21,75 +28,142 @@ class ActivationsStore:
while training SAEs.
"""

model: HookedTransformer
dataset: HfDataset
cached_activations_path: str | None
tokens_column: Literal["tokens", "input_ids", "text"]
hook_point_head_index: int | None

@classmethod
def from_config(
cls,
model: HookedTransformer,
cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
dataset: HfDataset | None = None,
create_dataloader: bool = True,
) -> "ActivationsStore":
cached_activations_path = cfg.cached_activations_path
# set cached_activations_path to None if we're not using cached activations
if (
isinstance(cfg, LanguageModelSAERunnerConfig)
and not cfg.use_cached_activations
):
cached_activations_path = None
return cls(
model=model,
dataset=dataset or cfg.dataset_path,
hook_point=cfg.hook_point,
hook_point_layers=listify(cfg.hook_point_layer),
hook_point_head_index=cfg.hook_point_head_index,
context_size=cfg.context_size,
d_in=cfg.d_in,
n_batches_in_buffer=cfg.n_batches_in_buffer,
total_training_tokens=cfg.total_training_tokens,
store_batch_size=cfg.store_batch_size,
train_batch_size=cfg.train_batch_size,
prepend_bos=cfg.prepend_bos,
device=cfg.device,
dtype=cfg.dtype,
cached_activations_path=cached_activations_path,
create_dataloader=create_dataloader,
)

def __init__(
self,
cfg: Any,
model: HookedTransformer,
dataset: HfDataset | None = None,
dataset: HfDataset | str,
hook_point: str,
hook_point_layers: list[int],
hook_point_head_index: int | None,
context_size: int,
d_in: int,
n_batches_in_buffer: int,
total_training_tokens: int,
store_batch_size: int,
train_batch_size: int,
prepend_bos: bool,
device: str | torch.device,
dtype: torch.dtype,
cached_activations_path: str | None = None,
create_dataloader: bool = True,
):
self.cfg = cfg
self.model = model
self.dataset = dataset or load_dataset(
cfg.dataset_path, split="train", streaming=True
self.dataset = (
load_dataset(dataset, split="train", streaming=True)
if isinstance(dataset, str)
else dataset
)
self.hook_point = hook_point
self.hook_point_layers = hook_point_layers
self.hook_point_head_index = hook_point_head_index
self.context_size = context_size
self.d_in = d_in
self.n_batches_in_buffer = n_batches_in_buffer
self.total_training_tokens = total_training_tokens
self.store_batch_size = store_batch_size
self.train_batch_size = train_batch_size
self.prepend_bos = prepend_bos
self.device = device
self.dtype = dtype
self.cached_activations_path = cached_activations_path

self.iterable_dataset = iter(self.dataset)

# Check if dataset is tokenized
dataset_sample = next(self.iterable_dataset)

# check if it's tokenized
if "tokens" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = True
self.is_dataset_tokenized = True
self.tokens_column = "tokens"
elif "input_ids" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = True
self.is_dataset_tokenized = True
self.tokens_column = "input_ids"
elif "text" in dataset_sample.keys():
self.cfg.is_dataset_tokenized = False
self.is_dataset_tokenized = False
self.tokens_column = "text"
else:
raise ValueError(
"Dataset must have a 'tokens', 'input_ids', or 'text' column."
)
self.iterable_dataset = iter(self.dataset) # Reset iterator after checking

if self.cfg.use_cached_activations: # EDIT: load from multi-layer acts
assert self.cfg.cached_activations_path is not None # keep pyright happy
if cached_activations_path is not None: # EDIT: load from multi-layer acts
assert self.cached_activations_path is not None # keep pyright happy

# Sanity check: does the cache directory exist?
assert os.path.exists(
self.cfg.cached_activations_path
), f"Cache directory {self.cfg.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."
self.cached_activations_path
), f"Cache directory {self.cached_activations_path} does not exist. Consider double-checking your dataset, model, and hook names."

self.next_cache_idx = 0 # which file to open next
self.next_idx_within_buffer = 0 # where to start reading from in that file

# Check that we have enough data on disk
first_buffer = torch.load(f"{self.cfg.cached_activations_path}/0.pt")
first_buffer = torch.load(f"{self.cached_activations_path}/0.pt")

buffer_size_on_disk = first_buffer.shape[0]
n_buffers_on_disk = len(os.listdir(self.cfg.cached_activations_path))
n_buffers_on_disk = len(os.listdir(self.cached_activations_path))

# Note: we're assuming all files have the same number of tokens
# (which seems reasonable imo since that's what our script does)
n_activations_on_disk = buffer_size_on_disk * n_buffers_on_disk
assert (
n_activations_on_disk > self.cfg.total_training_tokens
), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but cfg.total_training_tokens is {self.cfg.total_training_tokens/1e6:.1f}M."
n_activations_on_disk > self.total_training_tokens
), f"Only {n_activations_on_disk/1e6:.1f}M activations on disk, but total_training_tokens is {self.total_training_tokens/1e6:.1f}M."

# TODO add support for "mixed loading" (ie use cache until you run out, then switch over to streaming from HF)

if create_dataloader:
# fill the storage buffer with half a buffer, so we can mix it with a new buffer
self.storage_buffer = self.get_buffer(self.cfg.n_batches_in_buffer // 2)
self.storage_buffer = self.get_buffer(self.n_batches_in_buffer // 2)
self.dataloader = self.get_data_loader()

def get_batch_tokens(self):
"""
Streams a batch of tokens from a dataset.
"""

batch_size = self.cfg.store_batch_size
context_size = self.cfg.context_size
device = self.cfg.device
batch_size = self.store_batch_size
context_size = self.context_size
device = self.device

batch_tokens = torch.zeros(
size=(0, context_size), device=device, dtype=torch.long, requires_grad=False
@@ -124,7 +198,7 @@ def get_batch_tokens(self):
token_len -= space_left

# only add BOS if it's not already the first token
if self.cfg.prepend_bos:
if self.prepend_bos:
bos_token_id_tensor = torch.tensor(
[self.model.tokenizer.bos_token_id],
device=tokens.device,
@@ -160,23 +234,19 @@ def get_activations(self, batch_tokens: torch.Tensor):
d_in may result from a concatenated head dimension.
"""
layers = (
self.cfg.hook_point_layer
if isinstance(self.cfg.hook_point_layer, list)
else [self.cfg.hook_point_layer]
)
act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
layers = self.hook_point_layers
act_names = [self.hook_point.format(layer=layer) for layer in layers]
hook_point_max_layer = max(layers)
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
prepend_bos=self.cfg.prepend_bos,
prepend_bos=self.prepend_bos,
)[1]
activations_list = [layerwise_activations[act_name] for act_name in act_names]
if self.cfg.hook_point_head_index is not None:
if self.hook_point_head_index is not None:
activations_list = [
act[:, :, self.cfg.hook_point_head_index] for act in activations_list
act[:, :, self.hook_point_head_index] for act in activations_list
]
elif activations_list[0].ndim > 3: # if we have a head dimension
# flatten the head dimension
@@ -190,39 +260,35 @@ def get_buffer(self, n_batches_in_buffer: int):
return stacked_activations

def get_buffer(self, n_batches_in_buffer: int):
context_size = self.cfg.context_size
batch_size = self.cfg.store_batch_size
d_in = self.cfg.d_in
context_size = self.context_size
batch_size = self.store_batch_size
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = (
len(self.cfg.hook_point_layer)
if isinstance(self.cfg.hook_point_layer, list)
else 1
) # Number of hook points or layers
num_layers = len(self.hook_point_layers) # Number of hook points or layers

if self.cfg.use_cached_activations:
if self.cached_activations_path is not None:
# Load the activations from disk
buffer_size = total_size * context_size
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
dtype=self.dtype,
device=self.device,
)
n_tokens_filled = 0

# Assume activations for different layers are stored separately and need to be combined
while n_tokens_filled < buffer_size:
if not os.path.exists(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
f"{self.cached_activations_path}/{self.next_cache_idx}.pt"
):
print(
"\n\nWarning: Ran out of cached activation files earlier than expected."
)
print(
f"Expected to have {buffer_size} activations, but only found {n_tokens_filled}."
)
if buffer_size % self.cfg.total_training_tokens != 0:
if buffer_size % self.total_training_tokens != 0:
print(
"This might just be a rounding error — your batch_size * n_batches_in_buffer * context_size is not divisible by your total_training_tokens"
)
@@ -232,7 +298,7 @@ def get_buffer(self, n_batches_in_buffer: int):
return new_buffer

activations = torch.load(
f"{self.cfg.cached_activations_path}/{self.next_cache_idx}.pt"
f"{self.cached_activations_path}/{self.next_cache_idx}.pt"
)
taking_subset_of_file = False
if n_tokens_filled + activations.shape[0] > buffer_size:
@@ -257,8 +323,8 @@ def get_buffer(self, n_batches_in_buffer: int):
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, num_layers, d_in),
dtype=self.cfg.dtype,
device=self.cfg.device,
dtype=self.dtype,
device=self.device,
)

for refill_batch_idx_start in refill_iterator:
@@ -286,11 +352,11 @@ def get_data_loader(
"""

batch_size = self.cfg.train_batch_size
batch_size = self.train_batch_size

# 1. create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
[self.get_buffer(self.cfg.n_batches_in_buffer // 2), self.storage_buffer],
[self.get_buffer(self.n_batches_in_buffer // 2), self.storage_buffer],
dim=0,
)

@@ -325,14 +391,14 @@ def next_batch(self):
return next(self.dataloader)

def _get_next_dataset_tokens(self) -> torch.Tensor:
device = self.cfg.device
if not self.cfg.is_dataset_tokenized:
device = self.device
if not self.is_dataset_tokenized:
s = next(self.iterable_dataset)[self.tokens_column]
tokens = self.model.to_tokens(
s,
truncate=True,
move_to_device=True,
prepend_bos=self.cfg.prepend_bos,
prepend_bos=self.prepend_bos,
).squeeze(0)
assert (
len(tokens.shape) == 1
Expand All @@ -345,8 +411,17 @@ def _get_next_dataset_tokens(self) -> torch.Tensor:
requires_grad=False,
)
if (
not self.cfg.prepend_bos
not self.prepend_bos
and tokens[0] == self.model.tokenizer.bos_token_id # type: ignore
):
tokens = tokens[1:]
return tokens


T = TypeVar("T")


def listify(x: T | list[T]) -> list[T]:
if isinstance(x, list):
return x

return [x]
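
A quick note on the new listify helper (a hedged illustration, not part of the commit): it normalizes cfg.hook_point_layer, which may be a single int or a list of ints, into a list before it is passed to the constructor:

    # illustrative behaviour of the helper above
    listify(5)       # -> [5]
    listify([3, 7])  # -> [3, 7]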