16 changes: 8 additions & 8 deletions poetry.lock


12 changes: 12 additions & 0 deletions tests/unit/test_utils.py
@@ -165,6 +165,18 @@ def test_override_or_use_default_value():
assert utils.override_or_use_default_value(default_flag=False) == False


def test_is_library_available():
# Check libraries that are definitely included in the project's dependencies
assert utils.is_library_available("torch") is True
assert utils.is_library_available("numpy") is True

# Check the standard library (sys is always loaded)
assert utils.is_library_available("sys") is True

# Check a library that definitely does not exist
assert utils.is_library_available("completely_fake_library_name_123") is False


class TestAttentionMask:
prompts = [
"Hello world!",
2 changes: 1 addition & 1 deletion transformer_lens/HookedTransformer.py
@@ -2010,7 +2010,7 @@ def set_use_attn_result(self, use_attn_result: bool):

def set_use_split_qkv_input(self, use_split_qkv_input: bool):
"""
Toggles whether to allow editing of inputs to each attention head.
Toggles whether to allow editing of the separate Q, K, and V inputs to each attention head.
"""
self.cfg.use_split_qkv_input = use_split_qkv_input

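For context, a minimal usage sketch of this flag (the model name is illustrative; the hook names follow the blocks.{layer}.hook_q_input convention that use_split_qkv_input enables):

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
model.set_use_split_qkv_input(True)

def zero_head_zero_q(q_input, hook):
    # q_input has shape [batch, pos, n_heads, d_model]; zero only head 0's query input
    q_input[:, :, 0, :] = 0.0
    return q_input

logits = model.run_with_hooks(
    "Hello world!",
    fwd_hooks=[("blocks.0.hook_q_input", zero_head_zero_q)],
)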
9 changes: 8 additions & 1 deletion transformer_lens/train.py
@@ -9,13 +9,13 @@

import torch
import torch.optim as optim
import wandb
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from transformer_lens import utils
from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.utils import is_library_available


@dataclass
@@ -74,9 +74,16 @@ def train(
Returns:
The trained model
"""

torch.manual_seed(config.seed)
model.train()

if config.wandb:
if not is_library_available("wandb"):
raise ImportError("config.wandb is set, but wandb is not installed; install wandb or set config.wandb to False")

import wandb

if config.wandb_project_name is None:
config.wandb_project_name = "easy-transformer"
wandb.init(project=config.wandb_project_name, config=vars(config))
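The guard above is a lazy-import pattern: confirm the optional dependency is installed, then import it only inside the branch that needs it. A minimal standalone sketch of the same idea (the log_metrics helper is hypothetical, not part of this diff):

from transformer_lens.utils import is_library_available

def log_metrics(metrics: dict, use_wandb: bool) -> None:
    if use_wandb:
        if not is_library_available("wandb"):
            raise ImportError("wandb is not installed; install it or disable wandb logging")
        import wandb  # safe to import now: availability was checked without importing

        wandb.log(metrics)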
11 changes: 11 additions & 0 deletions transformer_lens/utils.py
@@ -6,11 +6,13 @@
from __future__ import annotations

import collections.abc
import importlib.util
import inspect
import json
import os
import re
import shutil
import sys
from copy import deepcopy
from typing import Any, List, Optional, Tuple, Union, cast

@@ -34,6 +36,15 @@
USE_DEFAULT_VALUE = None


def is_library_available(name: str) -> bool:
"""
Checks whether a library is installed in the current environment without importing it.
Importing a missing or broken package can crash the process (even with a segmentation fault), so this inspects only the import machinery.
"""

return name in sys.modules or importlib.util.find_spec(name) is not None


def select_compatible_kwargs(
kwargs_dict: dict[str, Any], callable: collections.abc.Callable
) -> dict[str, Any]:
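For illustration, a short sketch of how is_library_available can guard an optional import elsewhere (the package name here is illustrative):

from transformer_lens.utils import is_library_available

if is_library_available("datasets"):
    from datasets import load_dataset  # only imported when actually installed
else:
    load_dataset = None  # fall back gracefully when the optional dependency is missing

Checking sys.modules first also covers modules registered at runtime without a __spec__, for which importlib.util.find_spec would raise ValueError rather than return None.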