From b407aabfcd9d8ff9a8752a92a26701dbc8da04a2 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Thu, 30 Nov 2023 14:06:01 +0000 Subject: [PATCH] fix old test, may remove --- tests/acceptance/test_train_sae_toy_models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/acceptance/test_train_sae_toy_models.py b/tests/acceptance/test_train_sae_toy_models.py index 5b835822..1dcaca3a 100644 --- a/tests/acceptance/test_train_sae_toy_models.py +++ b/tests/acceptance/test_train_sae_toy_models.py @@ -1,12 +1,12 @@ +import einops import pytest import torch -import einops -import wandb +import wandb from sae_training.SAE import SAE -from sae_training.train_sae import train_sae from sae_training.toy_models import Config as ToyConfig from sae_training.toy_models import Model as ToyModel +from sae_training.train_sae import train_sae @pytest.fixture @@ -41,7 +41,7 @@ def test_train_sae_toy_models(model): sae = SAE(toy_config) # wandb.init(project="sae-training-test", config=toy_config) - sae = train_sae(sae, hidden.detach().squeeze(), use_wandb=False, l1_coeff=0.001, batch_size=32, n_epochs=10) + sae = train_sae(model, sae, hidden.detach().squeeze(), use_wandb=False, l1_coeff=0.001, batch_size=32, n_epochs=10) # wandb.finish()