diff --git a/sae_training/sae_group.py b/sae_training/sae_group.py index 51dc884d..568367f2 100644 --- a/sae_training/sae_group.py +++ b/sae_training/sae_group.py @@ -69,7 +69,7 @@ def load_from_pretrained(cls, path: str): try: if torch.backends.mps.is_available(): group = torch.load(path, map_location="mps") - group.cfg.device = "mps" + group["cfg"].device = "mps" else: group = torch.load(path) except Exception as e: