FVU discrepancies training/eval #130

@maxime-louis

Hey,
Thanks for the great code :)

We are training SAEs with your code and observe that the FVU reported in wandb during training is sometimes much lower than the FVU we measure offline on samples from the train set, using this script:

import os
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from sparsify import Sae as sparsify_Sae

# --- CONFIG: fill these in / adjust as needed ---
INIT_REPO_ID = "/path/to/sae_root"          # directory containing layers.X folders
MODEL_NAME = "Qwen/Qwen3-0.6B"              # HF model name
BATCH_SIZE = 8                              # just one batch
MAX_LENGTH = 256
N_EXAMPLES = 16                             # first N examples from dataset
# ------------------------------------------------
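# Assumed on-disk layout (per sparsify's save format, as I understand it):
# each hookpoint folder, e.g. layers.6/, contains cfg.json and sae.safetensors.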

device = "cuda"

with torch.inference_mode():
    # ---- Find first SAE hookpoint ----
    hookpoint_names = [
        p.name for p in Path(INIT_REPO_ID).iterdir()
        if p.is_dir() and p.name.startswith("layers.")
    ]
    if not hookpoint_names:
        raise ValueError(f"No 'layers.X' folders found in {INIT_REPO_ID}")

    hookpoint_name = sorted(hookpoint_names)[0]   # just pick the first one
    sae_layer = int(hookpoint_name.split(".")[-1])
    layer_index = sae_layer + 1  # as in original code (0 = embeddings)

    print(f"Using SAE: {hookpoint_name} (layer_index={layer_index})")

    # ---- Load model & tokenizer ----
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = (
        tokenizer.bos_token or tokenizer.pad_token or tokenizer.eos_token
    )

    # ---- Load a tiny slice of data (just first batch) ----
    dataset = load_dataset("EleutherAI/SmolLM2-135M-10B", split=f"train[:{N_EXAMPLES}]")
    texts = dataset["text"]

    batch_texts = [tokenizer.eos_token + t for t in texts[:BATCH_SIZE]]

    enc = tokenizer(
        batch_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
    )

    inputs = {k: v.to(device) for k, v in enc.items()}

    # ---- Forward pass to get hidden states ----
    outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states  # tuple: (layer0, layer1, ...)

    # ---- Load SAE for this hookpoint ----
    sae_path = os.path.join(INIT_REPO_ID, hookpoint_name)
    sae = sparsify_Sae.load_from_disk(sae_path).to(device)

    # ---- Get hidden states for that layer ----
    hidden_state_all = hidden_states[layer_index].to(torch.bfloat16)  # (B, S, D)
    B_l, S_l, D_l = hidden_state_all.shape

    hidden_flat = hidden_state_all.reshape(B_l * S_l, D_l)

    # drop the first token of each sequence (the prepended EOS; assumes
    # right-padding), then keep only non-pad positions
    mask = inputs["attention_mask"].clone()
    mask[:, 0] = 0
    valid_mask = mask.reshape(B_l * S_l) == 1
    hidden_flat = hidden_flat[valid_mask]

    if hidden_flat.numel() == 0:
        raise RuntimeError("No valid (non-pad) tokens in this batch.")

    # ---- SAE encode/decode ----
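    # For a TopK-style SAE, encode keeps the k largest latent activations
    # (top_acts) together with their indices, and decode reconstructs
    # (roughly) x_hat = sum_k top_acts[k] * W_dec[top_indices[k]] + b_dec.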
    top_acts, top_indices, _ = sae.encode(hidden_flat)
    decoded = sae.decode(top_acts=top_acts, top_indices=top_indices)

    # ---- Convert to fp32 for stats ----
    h = hidden_flat.to(torch.float32)
    d = decoded.to(torch.float32)

    # Reconstruction L2 loss (sum over tokens & dims)
    diff = h - d
    l2loss = (diff ** 2).sum().item()

    # Total variance of hidden states (same formula as original Welford single-batch case)
    mean = h.mean(dim=0)                     # (D,)
    M2 = ((h - mean) ** 2).sum(dim=0)        # (D,)
    total_variance = M2.sum().item()

    fvu = l2loss / total_variance
    print(f"FVU on first batch for {hookpoint_name}: {fvu}")

We've sometimes had discrepancies where the wandb FVU is 0.03 but the measured one is greater than 1! Is there something wrong with the way we are using the SAEs after training?

Thanks in advance,
Max

PS: this script measures FVU on a single batch, but I checked that using more samples does not "improve" the recovered value.
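
For completeness, here is a minimal sketch (not from the repo) of how one could accumulate the same statistics over several batches, using the batched Welford/Chan parallel-variance update so the total variance matches what a single concatenated batch would give. `batches` is a hypothetical iterable yielding the fp32 `(h, d)` pairs computed as in the script above:

import torch

def fvu_over_batches(batches):
    """FVU over an iterable of (hidden, decoded) fp32 tensor pairs.

    Uses the Chan et al. parallel variance update, so the result is
    identical to computing the stats on one concatenated batch.
    """
    n = 0                 # tokens seen so far
    mean = None           # running per-dim mean of hidden states, (D,)
    m2 = None             # running per-dim sum of squared deviations, (D,)
    l2loss = 0.0

    for h, d in batches:
        l2loss += ((h - d) ** 2).sum().item()

        if mean is None:
            mean = torch.zeros(h.shape[1], dtype=h.dtype, device=h.device)
            m2 = torch.zeros_like(mean)

        b = h.shape[0]
        mu_b = h.mean(dim=0)
        delta = mu_b - mean
        # within-chunk deviations plus the cross-chunk correction term
        m2 += ((h - mu_b) ** 2).sum(dim=0) + delta**2 * (n * b / (n + b))
        mean += delta * (b / (n + b))
        n += b

    return l2loss / m2.sum().item()

With a single `(h, d)` pair this reduces exactly to the single-batch formula in the script.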
