add gzip for pt artefacts
jbloom-md committed Dec 10, 2023
1 parent e90e54d commit 9614a23
Showing 4 changed files with 62 additions and 16 deletions.
sae_training/lm_runner.py (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ def language_model_sae_runner(cfg):
     )
 
     # save sae to checkpoints folder
-    path = f"{cfg.checkpoint_path}/final_{sparse_autoencoder.get_name()}.pt"
+    path = f"{cfg.checkpoint_path}/final_{sparse_autoencoder.get_name()}.pkl.gz"
     sparse_autoencoder.save_model(path)
 
     # upload to wandb
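
The final artefact written by the runner is therefore a gzip-compressed pickle rather than a raw .pt file. A minimal standalone sketch of the two formats side by side (the tensor names and shapes are illustrative, not the SAE's actual state):

    import gzip
    import os
    import pickle
    import tempfile

    import torch

    # Stand-in for the artefact the runner saves ({"cfg": ..., "state_dict": ...} in the real code).
    state = {"W_enc": torch.randn(512, 2048), "b_enc": torch.zeros(2048)}

    with tempfile.TemporaryDirectory() as tmp:
        pt_path = os.path.join(tmp, "final_sae.pt")
        gz_path = os.path.join(tmp, "final_sae.pkl.gz")

        torch.save(state, pt_path)           # raw torch serialization
        with gzip.open(gz_path, "wb") as f:  # gzip-compressed pickle
            pickle.dump(state, f)

        # Compare the on-disk sizes of the two artefact formats.
        print(f".pt:     {os.path.getsize(pt_path):,} bytes")
        print(f".pkl.gz: {os.path.getsize(gz_path):,} bytes")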
sae_training/sparse_autoencoder.py (45 changes: 32 additions & 13 deletions)
@@ -3,7 +3,9 @@
 https://github.com/ArthurConmy/sae/blob/main/sae/model.py
 """
 
+import gzip
 import os
+import pickle
 
 import einops
 import torch
@@ -215,13 +217,16 @@ def save_model(self, path: str):
             "state_dict": self.state_dict()
         }
 
-        torch.save(state_dict, path)
+        if path.endswith(".pt"):
+            torch.save(state_dict, path)
+        elif path.endswith("pkl.gz"):
+            with gzip.open(path, "wb") as f:
+                pickle.dump(state_dict, f)
+        else:
+            raise ValueError(f"Unexpected file extension: {path}, supported extensions are .pt and .pkl.gz")
 
+        print(f"Saved model to {path}")
 
-    def get_name(self):
-        sae_name = f"sparse_autoencoder_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}"
-        return sae_name
-        print(f"Saved model to {path}")
 
     @classmethod
     def load_from_pretrained(cls, path: str):
@@ -235,14 +240,24 @@ def load_from_pretrained(cls, path: str):
             raise FileNotFoundError(f"No file found at specified path: {path}")
 
         # Load the state dictionary
-        try:
-            if torch.backends.mps.is_available():
-                state_dict = torch.load(path, map_location="mps")
-                state_dict["cfg"].device = "mps"
-            else:
-                state_dict = torch.load(path)
-        except Exception as e:
-            raise IOError(f"Error loading the state dictionary: {e}")
+        if path.endswith(".pt"):
+            try:
+                if torch.backends.mps.is_available():
+                    state_dict = torch.load(path, map_location="mps")
+                    state_dict["cfg"].device = "mps"
+                else:
+                    state_dict = torch.load(path)
+            except Exception as e:
+                raise IOError(f"Error loading the state dictionary from .pt file: {e}")
+
+        elif path.endswith(".pkl.gz"):
+            try:
+                with gzip.open(path, 'rb') as f:
+                    state_dict = pickle.load(f)
+            except Exception as e:
+                raise IOError(f"Error loading the state dictionary from .pkl.gz file: {e}")
+        else:
+            raise ValueError(f"Unexpected file extension: {path}, supported extensions are .pt and .pkl.gz")
 
         # Ensure the loaded state contains both 'cfg' and 'state_dict'
         if 'cfg' not in state_dict or 'state_dict' not in state_dict:
@@ -253,3 +268,7 @@ def load_from_pretrained(cls, path: str):
         instance.load_state_dict(state_dict["state_dict"])
 
         return instance
+
+    def get_name(self):
+        sae_name = f"sparse_autoencoder_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}"
+        return sae_name
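
Read together, save_model and load_from_pretrained now dispatch on the file extension: .pt goes through torch.save/torch.load, and .pkl.gz through a gzip-wrapped pickle. A condensed standalone sketch of the round-trip (function names here are illustrative; the real methods additionally validate the loaded dict and re-point the config at the mps device when available):

    import gzip
    import pickle

    import torch

    def save_artefact(obj, path: str) -> None:
        # Mirrors the extension dispatch in save_model.
        if path.endswith(".pt"):
            torch.save(obj, path)
        elif path.endswith(".pkl.gz"):
            with gzip.open(path, "wb") as f:
                pickle.dump(obj, f)
        else:
            raise ValueError(f"Unexpected file extension: {path}")

    def load_artefact(path: str):
        # Mirrors the extension dispatch in load_from_pretrained.
        if path.endswith(".pt"):
            return torch.load(path)
        elif path.endswith(".pkl.gz"):
            with gzip.open(path, "rb") as f:
                return pickle.load(f)
        raise ValueError(f"Unexpected file extension: {path}")

One detail worth flagging: save_model checks path.endswith("pkl.gz") without the leading dot, while load_from_pretrained checks ".pkl.gz". The paths produced by the runners satisfy both, but the two checks are not identical.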
sae_training/train_sae_on_language_model.py (2 changes: 1 addition & 1 deletion)
@@ -172,7 +172,7 @@ def train_sae_on_language_model(
         # checkpoint if at checkpoint frequency
         if n_checkpoints > 0 and n_training_tokens > checkpoint_thresholds[0]:
             cfg = sparse_autoencoder.cfg
-            path = f"{sparse_autoencoder.cfg.checkpoint_path}/{n_training_tokens}_{sparse_autoencoder.get_name()}.pt"
+            path = f"{sparse_autoencoder.cfg.checkpoint_path}/{n_training_tokens}_{sparse_autoencoder.get_name()}.pkl.gz"
             sparse_autoencoder.save_model(path)
             checkpoint_thresholds.pop(0)
             if len(checkpoint_thresholds) == 0:
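
The checkpoint branch pops the lowest outstanding token threshold each time it fires, so each threshold triggers exactly one artefact. A toy illustration of that mechanism with hypothetical values (the real checkpoint_thresholds list is built from the run config):

    # Hypothetical: four evenly spaced checkpoints over a 1M-token run.
    total_training_tokens = 1_000_000
    n_checkpoints = 4
    checkpoint_thresholds = [
        total_training_tokens * (i + 1) // n_checkpoints for i in range(n_checkpoints)
    ]  # [250000, 500000, 750000, 1000000]

    n_training_tokens = 260_000
    if n_checkpoints > 0 and n_training_tokens > checkpoint_thresholds[0]:
        # The training loop would write a {n_training_tokens}_..._.pkl.gz artefact here.
        checkpoint_thresholds.pop(0)  # retire the threshold that was just crossed

    print(checkpoint_thresholds)  # [500000, 750000, 1000000]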
tests/unit/test_sparse_autoencoder.py (29 changes: 28 additions & 1 deletion)
@@ -100,7 +100,7 @@ def test_save_model(cfg):
             state_dict_loaded["state_dict"][key]
         )
 
-def test_load_from_pretrained(cfg):
+def test_load_from_pretrained_pt(cfg):
 
     with tempfile.TemporaryDirectory() as tmpdirname:
 
@@ -126,6 +126,33 @@ def test_load_from_pretrained(cfg):
                sparse_autoencoder_state_dict[key], # pylint: disable=unsubscriptable-object
                sparse_autoencoder_loaded_state_dict[key] # pylint: disable=unsubscriptable-object
            )
+
+def test_load_from_pretrained_pkl_gz(cfg):
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+
+        # assert file does not exist
+        assert os.path.exists(tmpdirname + "/test.pkl.gz") == False
+
+        sparse_autoencoder = SparseAutoencoder(cfg)
+        sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
+        sparse_autoencoder.save_model(tmpdirname + "/test.pkl.gz")
+
+        assert os.path.exists(tmpdirname + "/test.pkl.gz")
+
+        sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(tmpdirname + "/test.pkl.gz")
+        sparse_autoencoder_loaded.cfg.device = "cpu" # might autoload onto mps
+        sparse_autoencoder_loaded = sparse_autoencoder_loaded.to("cpu")
+        sparse_autoencoder_loaded_state_dict = sparse_autoencoder_loaded.state_dict()
+        # check cfg matches the original
+        assert sparse_autoencoder_loaded.cfg == cfg
+
+        # check state_dict matches the original
+        for key in sparse_autoencoder.state_dict().keys():
+            assert torch.allclose(
+                sparse_autoencoder_state_dict[key], # pylint: disable=unsubscriptable-object
+                sparse_autoencoder_loaded_state_dict[key] # pylint: disable=unsubscriptable-object
+            )
 
 
 def test_sparse_autoencoder_forward(sparse_autoencoder):
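
Assuming the repository's standard pytest setup, the new round-trip test can be run on its own with: pytest tests/unit/test_sparse_autoencoder.py -k test_load_from_pretrained_pkl_gz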
