fix: Revert "feat: Add kl eval (#124)" (#127)
This reverts commit c1d9cbe.
jbloomAus authored May 7, 2024
1 parent 7264f99 commit 1a0619c
Showing 1 changed file with 7 additions and 29 deletions.
36 changes: 7 additions & 29 deletions sae_lens/training/evals.py
@@ -39,7 +39,6 @@ def run_evals(
     ntp_loss = losses_df["loss"].mean()
     recons_loss = losses_df["recons_loss"].mean()
     zero_abl_loss = losses_df["zero_abl_loss"].mean()
-    d_kl = losses_df["d_kl"].mean()

     # get cache
     _, cache = model.run_with_cache(
@@ -85,8 +84,6 @@ def run_evals(
         f"metrics/ce_loss_without_sae{suffix}": ntp_loss,
         f"metrics/ce_loss_with_sae{suffix}": recons_loss,
         f"metrics/ce_loss_with_ablation{suffix}": zero_abl_loss,
-        # KL divergence against intact model
-        f"metrics/kl_div{suffix}": d_kl,
     }

     if wandb.run is not None:
@@ -107,7 +104,7 @@ def recons_loss_batched(
     losses = []
     for _ in range(n_batches):
         batch_tokens = activation_store.get_batch_tokens()
-        score, loss, recons_loss, zero_abl_loss, d_kl = get_recons_loss(
+        score, loss, recons_loss, zero_abl_loss = get_recons_loss(
             sparse_autoencoder, model, batch_tokens
         )
         losses.append(
@@ -116,13 +113,11 @@
                 loss.mean().item(),
                 recons_loss.mean().item(),
                 zero_abl_loss.mean().item(),
-                d_kl.mean().item(),
             )
         )

     losses = pd.DataFrame(
-        losses,
-        columns=cast(Any, ["score", "loss", "recons_loss", "zero_abl_loss", "d_kl"]),
+        losses, columns=cast(Any, ["score", "loss", "recons_loss", "zero_abl_loss"])
     )

     return losses
@@ -135,11 +130,10 @@ def get_recons_loss(
     batch_tokens: torch.Tensor,
 ):
     hook_point = sparse_autoencoder.cfg.hook_point
-    model_outs = model(
-        batch_tokens, return_type="both", **sparse_autoencoder.cfg.model_kwargs
+    loss = model(
+        batch_tokens, return_type="loss", **sparse_autoencoder.cfg.model_kwargs
     )
     head_index = sparse_autoencoder.cfg.hook_point_head_index
-    loss = model_outs.loss

     def standard_replacement_hook(activations: torch.Tensor, hook: Any):
         activations = sparse_autoencoder.forward(activations).sae_out.to(
@@ -172,13 +166,12 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any):
     else:
         replacement_hook = standard_replacement_hook

-    recons_outs = model.run_with_hooks(
+    recons_loss = model.run_with_hooks(
         batch_tokens,
-        return_type="both",
+        return_type="loss",
         fwd_hooks=[(hook_point, partial(replacement_hook))],
         **sparse_autoencoder.cfg.model_kwargs,
     )
-    recons_loss = recons_outs.loss

     zero_abl_loss = model.run_with_hooks(
         batch_tokens,
@@ -192,22 +185,7 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any):

     score = (zero_abl_loss - recons_loss) / div_val

-    # KL divergence
-    model_logits = model_outs.logits  # [batch, pos, d_vocab]
-    model_logprobs = torch.nn.functional.log_softmax(model_logits, dim=-1)
-    recons_logits = recons_outs.logits
-    recons_logprobs = torch.nn.functional.log_softmax(recons_logits, dim=-1)
-    # Note: PyTorch KL is backwards compared to the mathematical definition
-    # target distribution comes second, see
-    # https://pytorch.org/docs/stable/generated/torch.nn.functional.kl_div.html
-    d_kl = torch.nn.functional.kl_div(
-        recons_logprobs,
-        model_logprobs,
-        reduction="batchmean",
-        log_target=True,  # for numerics
-    )
-
-    return score, loss, recons_loss, zero_abl_loss, d_kl
+    return score, loss, recons_loss, zero_abl_loss


 def zero_ablate_hook(activations: torch.Tensor, hook: Any):
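
For reference, the reverted metric compared the SAE-patched model's next-token distribution against the intact model's. The following is a minimal, self-contained sketch of that computation, not the current SAELens API: random logits stand in for model_outs.logits and recons_outs.logits from the removed code, and the shapes are illustrative only.

import torch
import torch.nn.functional as F

# Stand-ins for the intact model's logits and the SAE-patched model's logits,
# each of shape [batch, pos, d_vocab]; sizes here are arbitrary.
model_logits = torch.randn(2, 8, 512)
recons_logits = model_logits + 0.1 * torch.randn_like(model_logits)

model_logprobs = F.log_softmax(model_logits, dim=-1)
recons_logprobs = F.log_softmax(recons_logits, dim=-1)

# F.kl_div takes the approximating distribution first and the reference
# (target) distribution second, i.e. KL(target || input) in math notation;
# log_target=True keeps both arguments in log space for numerical stability.
d_kl = F.kl_div(
    recons_logprobs,
    model_logprobs,
    reduction="batchmean",
    log_target=True,
)
print(d_kl.item())

Note that reduction="batchmean" divides the summed KL by the batch dimension only, so the reported value grows with sequence length; the sketch keeps that behavior to match the removed code rather than as a recommendation.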

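The evaluation that survives the revert reports a reconstruction score alongside the three cross-entropy losses. Below is a hedged sketch of that score; div_val is defined outside the hunks shown here, so the sketch assumes it is the gap between the zero-ablation loss and the clean loss, the usual normalization for this metric.

import torch

def recons_score(
    loss: torch.Tensor, recons_loss: torch.Tensor, zero_abl_loss: torch.Tensor
) -> torch.Tensor:
    # Fraction of the loss gap recovered by splicing the SAE reconstruction
    # into the forward pass: ~1.0 means the SAE output matches the clean
    # activations, ~0.0 means it is no better than zero-ablating the hook.
    # Assumption: div_val = zero_abl_loss - loss (not shown in this diff).
    div_val = zero_abl_loss - loss
    return (zero_abl_loss - recons_loss) / div_val

# Example with scalar cross-entropy losses (in nats):
print(recons_score(torch.tensor(3.2), torch.tensor(3.4), torch.tensor(10.0)))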