-
Notifications
You must be signed in to change notification settings - Fork 95
Open
Description
Hey,
Thanks for the great code :)
We are training SAEs with your code, and observe that sometimes the FVU during training (reported in wandb) is much lower than the one we measure offline, on samples from the train set, using this script:
import os
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from sparsify import Sae as sparsify_Sae
# --- CONFIG: fill these in / adjust as needed ---
INIT_REPO_ID = "/path/to/sae_root" # directory containing layers.X folders
MODEL_NAME = "Qwen/Qwen3-0.6B" # HF model name
BATCH_SIZE = 8 # just one batch
MAX_LENGTH = 256
N_EXAMPLES = 16 # first N examples from dataset
# ------------------------------------------------
device = "cuda"
with torch.inference_mode():
# ---- Find first SAE hookpoint ----
hookpoint_names = [
p.name for p in Path(INIT_REPO_ID).iterdir()
if p.is_dir() and p.name.startswith("layers.")
]
if not hookpoint_names:
raise ValueError(f"No 'layers.X' folders found in {INIT_REPO_ID}")
hookpoint_name = sorted(hookpoint_names)[0] # just pick the first one
sae_layer = int(hookpoint_name.split(".")[-1])
layer_index = sae_layer + 1 # as in original code (0 = embeddings)
print(f"Using SAE: {hookpoint_name} (layer_index={layer_index})")
# ---- Load model & tokenizer ----
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16
).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = (
tokenizer.bos_token or tokenizer.pad_token or tokenizer.eos_token
)
# ---- Load a tiny slice of data (just first batch) ----
dataset = load_dataset("EleutherAI/SmolLM2-135M-10B", split=f"train[:{N_EXAMPLES}]")
texts = dataset["text"]
batch_texts = [tokenizer.eos_token + t for t in texts[:BATCH_SIZE]]
enc = tokenizer(
batch_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=MAX_LENGTH,
)
inputs = {k: v.to(device) for k, v in enc.items()}
# ---- Forward pass to get hidden states ----
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states # tuple: (layer0, layer1, ...)
B, S = inputs["attention_mask"].shape
attention_mask_flat = inputs["attention_mask"].reshape(B * S)
# ---- Load SAE for this hookpoint ----
sae_path = os.path.join(INIT_REPO_ID, hookpoint_name)
sae = sparsify_Sae.load_from_disk(sae_path).to(device)
# ---- Get hidden states for that layer ----
hidden_state_all = hidden_states[layer_index].to(torch.bfloat16) # (B, S, D)
B_l, S_l, D_l = hidden_state_all.shape
hidden_flat = hidden_state_all.reshape(B_l * S_l, D_l)
# drop first token, mask pads
hidden_flat = hidden_flat[1:]
mask_flat = attention_mask_flat[1:]
valid_mask = mask_flat == 1
hidden_flat = hidden_flat[valid_mask]
if hidden_flat.numel() == 0:
raise RuntimeError("No valid (non-pad) tokens in this batch.")
# ---- SAE encode/decode ----
top_acts, top_indices, _ = sae.encode(hidden_flat)
decoded = sae.decode(top_acts=top_acts, top_indices=top_indices)
# ---- Convert to fp32 for stats ----
h = hidden_flat.to(torch.float32)
d = decoded.to(torch.float32)
# Reconstruction L2 loss (sum over tokens & dims)
diff = h - d
l2loss = (diff ** 2).sum().item()
# Total variance of hidden states (same formula as original Welford single-batch case)
mean = h.mean(dim=0) # (D,)
M2 = ((h - mean) ** 2).sum(dim=0) # (D,)
total_variance = M2.sum().item()
fvu = l2loss / total_variance
print(f"FVU on first batch for {hookpoint_name}: {fvu}")
We've had discrepancies sometimes where wandb FVU is 0.03 and measured one is greater than 1 ! Is there something wrong with the way we are using the SAEs after training ?
Thanks in advance,
Max
PS: this script measure on a single batch, but I checked and more samples do not "improve" the recovered FVU value.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels